From 9e6adcbca00e8d67c2fef559b1523f19c3242348 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Wed, 9 Apr 2025 14:17:47 +0200 Subject: [PATCH] aco: select fp32 to float8 conversions Reviewed-by: Rhys Perry Part-of: --- src/amd/compiler/aco_ir.cpp | 3 ++ .../instruction_selection/aco_isel_setup.cpp | 4 ++ .../aco_select_nir_alu.cpp | 41 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp index 4451599a1ab..0b906e0dab3 100644 --- a/src/amd/compiler/aco_ir.cpp +++ b/src/amd/compiler/aco_ir.cpp @@ -583,6 +583,9 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx) case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2; case aco_opcode::v_interp_p2_f16_f32_inreg: case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0; + case aco_opcode::v_cvt_pk_fp8_f32: + case aco_opcode::p_v_cvt_pk_fp8_f32_ovfl: + case aco_opcode::v_cvt_pk_bf8_f32: return idx == -1; default: return gfx_level >= GFX11 && (get_gfx11_true16_mask(op) & BITFIELD_BIT(idx == -1 ? 3 : idx)); } diff --git a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp index 78f2eb54b87..4368698ff32 100644 --- a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp +++ b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp @@ -415,6 +415,10 @@ init_context(isel_context* ctx, nir_shader* shader) regclasses[alu_instr->src[0].src.ssa->index].type() == RegType::vgpr) type = RegType::vgpr; break; + case nir_op_f2e4m3fn: + case nir_op_f2e4m3fn_sat: + case nir_op_f2e5m2: + case nir_op_f2e5m2_sat: case nir_op_fmulz: case nir_op_ffmaz: case nir_op_f2f64: diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp index 37302231cdd..7007c4b0118 100644 --- a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp +++ b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp @@ -2553,6 +2553,47 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src); break; } + case nir_op_f2e4m3fn: + case nir_op_f2e4m3fn_sat: + case nir_op_f2e5m2: + case nir_op_f2e5m2_sat: { + Operand src[2]; + if (instr->def.num_components == 2) { + Temp pk_src = get_ssa_temp(ctx, instr->src[0].src.ssa); + RegClass rc = RegClass(pk_src.regClass().type(), 1); + for (unsigned i = 0; i < 2; i++) + src[i] = Operand(emit_extract_vector(ctx, pk_src, instr->src[0].swizzle[i], rc)); + } else { + assert(instr->def.num_components == 1); + src[0] = Operand(get_alu_src(ctx, instr->src[0])); + src[1] = Operand::c32(0); + } + + /* Ideally we would want to use FP16_OVFL for the sat variants, + * but the ISA doc is wrong and Inf isn't clamped to max_float. + */ + bool clamp = instr->op == nir_op_f2e4m3fn_sat || instr->op == nir_op_f2e5m2_sat; + if (clamp) { + Temp max_float = bld.copy( + bld.def(s1), Operand::c32(fui(instr->op == nir_op_f2e4m3fn_sat ? 448.0f : 57344.0f))); + + for (unsigned i = 0; i < instr->def.num_components; i++) { + /* use minimum variant because it preserves NaN. */ + Instruction* clamped = bld.vop3(aco_opcode::v_minimummaximum_f32, bld.def(v1), src[i], + max_float, max_float); + clamped->valu().neg[2] = true; + src[i] = Operand(clamped->definitions[0].getTemp()); + } + } + + aco_opcode opcode = instr->op == nir_op_f2e4m3fn || instr->op == nir_op_f2e4m3fn_sat + ? aco_opcode::v_cvt_pk_fp8_f32 + : aco_opcode::v_cvt_pk_bf8_f32; + bld.vop3(opcode, Definition(dst), src[0], src[1]); + if (instr->def.num_components == 2) + emit_split_vector(ctx, dst, 2); + break; + } case nir_op_i2f16: { Temp src = get_alu_src(ctx, instr->src[0]); const unsigned input_size = instr->src[0].src.ssa->bit_size;