aco: select fp32 to float8 conversions

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35434>
This commit is contained in:
Georg Lehmann 2025-04-09 14:17:47 +02:00 committed by Marge Bot
parent 3a45802514
commit 9e6adcbca0
3 changed files with 48 additions and 0 deletions

View file

@ -583,6 +583,9 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2;
case aco_opcode::v_interp_p2_f16_f32_inreg:
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0;
case aco_opcode::v_cvt_pk_fp8_f32:
case aco_opcode::p_v_cvt_pk_fp8_f32_ovfl:
case aco_opcode::v_cvt_pk_bf8_f32: return idx == -1;
default:
return gfx_level >= GFX11 && (get_gfx11_true16_mask(op) & BITFIELD_BIT(idx == -1 ? 3 : idx));
}

View file

@ -415,6 +415,10 @@ init_context(isel_context* ctx, nir_shader* shader)
regclasses[alu_instr->src[0].src.ssa->index].type() == RegType::vgpr)
type = RegType::vgpr;
break;
case nir_op_f2e4m3fn:
case nir_op_f2e4m3fn_sat:
case nir_op_f2e5m2:
case nir_op_f2e5m2_sat:
case nir_op_fmulz:
case nir_op_ffmaz:
case nir_op_f2f64:

View file

@ -2553,6 +2553,47 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src);
break;
}
case nir_op_f2e4m3fn:
case nir_op_f2e4m3fn_sat:
case nir_op_f2e5m2:
case nir_op_f2e5m2_sat: {
Operand src[2];
if (instr->def.num_components == 2) {
Temp pk_src = get_ssa_temp(ctx, instr->src[0].src.ssa);
RegClass rc = RegClass(pk_src.regClass().type(), 1);
for (unsigned i = 0; i < 2; i++)
src[i] = Operand(emit_extract_vector(ctx, pk_src, instr->src[0].swizzle[i], rc));
} else {
assert(instr->def.num_components == 1);
src[0] = Operand(get_alu_src(ctx, instr->src[0]));
src[1] = Operand::c32(0);
}
/* Ideally we would want to use FP16_OVFL for the sat variants,
* but the ISA doc is wrong and Inf isn't clamped to max_float.
*/
bool clamp = instr->op == nir_op_f2e4m3fn_sat || instr->op == nir_op_f2e5m2_sat;
if (clamp) {
Temp max_float = bld.copy(
bld.def(s1), Operand::c32(fui(instr->op == nir_op_f2e4m3fn_sat ? 448.0f : 57344.0f)));
for (unsigned i = 0; i < instr->def.num_components; i++) {
/* use minimum variant because it preserves NaN. */
Instruction* clamped = bld.vop3(aco_opcode::v_minimummaximum_f32, bld.def(v1), src[i],
max_float, max_float);
clamped->valu().neg[2] = true;
src[i] = Operand(clamped->definitions[0].getTemp());
}
}
aco_opcode opcode = instr->op == nir_op_f2e4m3fn || instr->op == nir_op_f2e4m3fn_sat
? aco_opcode::v_cvt_pk_fp8_f32
: aco_opcode::v_cvt_pk_bf8_f32;
bld.vop3(opcode, Definition(dst), src[0], src[1]);
if (instr->def.num_components == 2)
emit_split_vector(ctx, dst, 2);
break;
}
case nir_op_i2f16: {
Temp src = get_alu_src(ctx, instr->src[0]);
const unsigned input_size = instr->src[0].src.ssa->bit_size;