nir,aco: optimize FP16_OVFL pattern created by vkd3d-proton

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35434>
This commit is contained in:
Georg Lehmann 2025-06-10 15:10:02 +02:00 committed by Marge Bot
parent 9e6adcbca0
commit f047a67fba
6 changed files with 34 additions and 1 deletions

View file

@ -88,6 +88,7 @@ void ac_nir_set_options(struct radeon_info *info, bool use_llvm,
options->has_msad = true;
options->has_shfr32 = true;
options->has_mul24_relaxed = true;
options->has_f2e4m3fn_satfn = !use_llvm && info->gfx_level >= GFX12;
options->lower_int64_options = nir_lower_imul64 | nir_lower_imul_high64 | nir_lower_imul_2x32_64 | nir_lower_divmod64 |
nir_lower_minmax64 | nir_lower_iabs64 | nir_lower_iadd_sat64 | nir_lower_conv64 |
nir_lower_bitfield_extract64;

View file

@ -417,6 +417,7 @@ init_context(isel_context* ctx, nir_shader* shader)
break;
case nir_op_f2e4m3fn:
case nir_op_f2e4m3fn_sat:
case nir_op_f2e4m3fn_satfn:
case nir_op_f2e5m2:
case nir_op_f2e5m2_sat:
case nir_op_fmulz:

View file

@ -2555,6 +2555,7 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
}
case nir_op_f2e4m3fn:
case nir_op_f2e4m3fn_sat:
case nir_op_f2e4m3fn_satfn:
case nir_op_f2e5m2:
case nir_op_f2e5m2_sat: {
Operand src[2];
@ -2588,7 +2589,8 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
aco_opcode opcode = instr->op == nir_op_f2e4m3fn || instr->op == nir_op_f2e4m3fn_sat
? aco_opcode::v_cvt_pk_fp8_f32
: aco_opcode::v_cvt_pk_bf8_f32;
: instr->op == nir_op_f2e4m3fn_satfn ? aco_opcode::p_v_cvt_pk_fp8_f32_ovfl
: aco_opcode::v_cvt_pk_bf8_f32;
bld.vop3(opcode, Definition(dst), src[0], src[1]);
if (instr->def.num_components == 2)
emit_split_vector(ctx, dst, 2);

View file

@ -1770,6 +1770,8 @@ opcode("bfdot2_bfadd", 1, tint16, [2, 2, 1], [tint16, tint16, tint16],
unop_numeric_convert("e4m3fn2f", tfloat32, tuint8, "_mesa_e4m3fn_to_float(src0)")
unop_numeric_convert("f2e4m3fn", tuint8, tfloat32, "_mesa_float_to_e4m3fn(src0)")
unop_numeric_convert("f2e4m3fn_sat", tuint8, tfloat32, "_mesa_float_to_e4m3fn_sat(src0)")
# AMD specific conversion that clamps finite values but not inf (GFX12 FP16_OVFL=1 behavior)
unop_numeric_convert("f2e4m3fn_satfn", tuint8, tfloat32, "isinf(src0) ? 0x7f : _mesa_float_to_e4m3fn_sat(src0)")
unop_numeric_convert("e5m22f", tfloat32, tuint8, "_mesa_e5m2_to_float(src0)")
unop_numeric_convert("f2e5m2", tuint8, tfloat32, "_mesa_float_to_e5m2(src0)")

View file

@ -3144,6 +3144,30 @@ optimizations += [
(('iadd', ('msad_4x8', a, b, 0), c), ('msad_4x8', a, b, c)),
]
# VKD3D-Proton patterns for FP16_OVFL=1 conversion to e4m3fn
def vkd3d_proton_f2e4m3_ovfl(variant, x, nan):
    """Build one search-pattern variant of the vkd3d-proton infinity check.

    The result is the tuple ``('bcsel', cond, '#<nan>(is_nan)', x)``, where
    ``cond`` is an ``feq`` comparison of ``x`` against an infinity:

    * variant 0: ``fabs(x) == +inf``
    * variant 1: ``x == +inf``, with ``x`` annotated ``(is_not_negative)``
    * variant 2: ``x == -inf``, with ``x`` annotated ``(is_not_positive)``

    ``x`` and ``nan`` are variable names that get interpolated into the
    pattern strings.

    Raises ValueError for any other ``variant`` instead of failing later with
    an UnboundLocalError on ``cond``.
    """
    if variant == 0:
        cond = ('feq', ('fabs', x), float('inf'))
    elif variant == 1:
        cond = ('feq', f'{x}(is_not_negative)', float('inf'))
    elif variant == 2:
        cond = ('feq', f'{x}(is_not_positive)', -float('inf'))
    else:
        raise ValueError(f'unknown f2e4m3 ovfl pattern variant: {variant!r}')
    return ('bcsel', cond, f'#{nan}(is_nan)', x)
# Scalar case: for each of the three infinity-check variants, replace
# f2e4m3fn_sat applied to the bcsel pattern with f2e4m3fn_satfn applied to
# the unselected input. Only enabled when options->has_f2e4m3fn_satfn is set.
for var in range(3):
optimizations += [
(('f2e4m3fn_sat', vkd3d_proton_f2e4m3_ovfl(var, a, b)),
('f2e4m3fn_satfn', a), 'options->has_f2e4m3fn_satfn'),
]
# vec2 case: each component may use a different variant, so a rule is added
# for every (var0, var1) combination (3 x 3 = 9 rules). The replacement
# converts vec2(a, c) with a single f2e4m3fn_satfn.
for var0, var1 in itertools.product(range(3), repeat=2):
optimizations += [
(('f2e4m3fn_sat', ('vec2', vkd3d_proton_f2e4m3_ovfl(var0, a, b),
vkd3d_proton_f2e4m3_ovfl(var1, c, d))),
('f2e4m3fn_satfn', ('vec2', a, c)), 'options->has_f2e4m3fn_satfn'),
]
# "all_equal(eq(a, b), vec(~0))" is the same as "all_equal(a, b)"
# "any_nequal(neq(a, b), vec(0))" is the same as "any_nequal(a, b)"

View file

@ -638,6 +638,9 @@ typedef struct nir_shader_compiler_options {
/** Backend supports msad_u4x8. */
bool has_msad;
/** Backend supports f2e4m3fn_satfn */
bool has_f2e4m3fn_satfn;
/**
* Is this the Intel vec4 backend?
*