From a90d4d340c3e9be4cd0c6fe4d5fcd15fa383a13b Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Thu, 21 Sep 2023 20:38:00 +0200 Subject: [PATCH] aco/gfx11.5: select SALU float conversions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewed-by: Daniel Schürmann Part-of: --- .../compiler/aco_instruction_selection.cpp | 155 ++++++++++++------ .../aco_instruction_selection_setup.cpp | 22 +-- 2 files changed, 115 insertions(+), 62 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 1b73bd58110..d7af93236c3 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1345,12 +1345,16 @@ emit_vec2_f2f16(isel_context* ctx, nir_alu_instr* instr, Temp dst) Temp src0 = emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc); Temp src1 = emit_extract_vector(ctx, src, instr->src[0].swizzle[1], rc); - src1 = as_vgpr(ctx, src1); - if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9) - bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src0, src1); - else - bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1); - emit_split_vector(ctx, dst, 2); + if (dst.regClass() == s1) { + bld.sop2(aco_opcode::s_cvt_pk_rtz_f16_f32, Definition(dst), src0, src1); + } else { + src1 = as_vgpr(ctx, src1); + if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9) + bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src0, src1); + else + bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1); + emit_split_vector(ctx, dst, 2); + } } void @@ -2929,13 +2933,20 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->op == nir_op_f2f16_rtne && ctx->block->fp_mode.round16_64 != fp_round_ne) + if (instr->op == nir_op_f2f16_rtne && ctx->block->fp_mode.round16_64 != fp_round_ne) { /* We emit s_round_mode/s_setreg_imm32 in lower_to_hw_instr to * keep value numbering and the scheduler simpler. */ - bld.vop1(aco_opcode::p_v_cvt_f16_f32_rtne, Definition(dst), src); - else - bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + if (dst.regClass() == v2b) + bld.vop1(aco_opcode::p_v_cvt_f16_f32_rtne, Definition(dst), src); + else + bld.sop1(aco_opcode::p_s_cvt_f16_f32_rtne, Definition(dst), src); + } else { + if (dst.regClass() == v2b) + bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + else + bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src); + } break; } case nir_op_f2f16_rtz: { @@ -2945,16 +2956,26 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } Temp src = get_alu_src(ctx, instr->src[0]); - if (ctx->block->fp_mode.round16_64 == fp_round_tz) - bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); - else if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9) + if (ctx->block->fp_mode.round16_64 == fp_round_tz) { + if (dst.regClass() == v2b) + bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + else + bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src); + } else if (dst.regClass() == s1) { + bld.sop2(aco_opcode::s_cvt_pk_rtz_f16_f32, Definition(dst), src, Operand::zero()); + } else if (ctx->program->gfx_level == GFX8 || ctx->program->gfx_level == GFX9) { bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32_e64, Definition(dst), src, Operand::zero()); - else + } else { bld.vop2(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src, as_vgpr(ctx, src)); + } break; } case nir_op_f2f32: { - if (instr->src[0].src.ssa->bit_size == 16) { + if (dst.regClass() == s1) { + assert(instr->src[0].src.ssa->bit_size == 16); + Temp src = get_alu_src(ctx, instr->src[0]); + bld.sop1(aco_opcode::s_cvt_f32_f16, Definition(dst), src); + } else if (instr->src[0].src.ssa->bit_size == 16) { emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f16, dst); } else if (instr->src[0].src.ssa->bit_size == 64) { emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f64, dst); @@ -2970,27 +2991,36 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } case nir_op_i2f16: { - assert(dst.regClass() == v2b); Temp src = get_alu_src(ctx, instr->src[0]); const unsigned input_size = instr->src[0].src.ssa->bit_size; - if (input_size <= 16) { - /* Expand integer to the size expected by the uint→float converter used below */ - unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32); - if (input_size != target_size) { - src = convert_int(ctx, bld, src, input_size, target_size, true); + if (dst.regClass() == v2b) { + if (input_size <= 16) { + /* Expand integer to the size expected by the uint→float converter used below */ + unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32); + if (input_size != target_size) { + src = convert_int(ctx, bld, src, input_size, target_size, true); + } } - } - if (ctx->program->gfx_level >= GFX8 && input_size <= 16) { - bld.vop1(aco_opcode::v_cvt_f16_i16, Definition(dst), src); + if (ctx->program->gfx_level >= GFX8 && input_size <= 16) { + bld.vop1(aco_opcode::v_cvt_f16_i16, Definition(dst), src); + } else { + /* Large 32bit inputs need to return +-inf/FLOAT_MAX. + * + * This is also the fallback-path taken on GFX7 and earlier, which + * do not support direct f16⟷i16 conversions. + */ + src = bld.vop1(aco_opcode::v_cvt_f32_i32, bld.def(v1), src); + bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + } + } else if (dst.regClass() == s1) { + if (input_size <= 16) { + src = convert_int(ctx, bld, src, input_size, 32, true); + } + src = bld.sop1(aco_opcode::s_cvt_f32_i32, bld.def(s1), src); + bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src); } else { - /* Large 32bit inputs need to return +-inf/FLOAT_MAX. - * - * This is also the fallback-path taken on GFX7 and earlier, which - * do not support direct f16⟷i16 conversions. - */ - src = bld.vop1(aco_opcode::v_cvt_f32_i32, bld.def(v1), src); - bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } break; } @@ -3003,7 +3033,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) /* Sign-extend to 32-bits */ src = convert_int(ctx, bld, src, input_size, 32, true); } - bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src); + if (dst.regClass() == v1) + bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src); + else + bld.sop1(aco_opcode::s_cvt_f32_i32, Definition(dst), src); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -3021,27 +3054,36 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) break; } case nir_op_u2f16: { - assert(dst.regClass() == v2b); Temp src = get_alu_src(ctx, instr->src[0]); const unsigned input_size = instr->src[0].src.ssa->bit_size; - if (input_size <= 16) { - /* Expand integer to the size expected by the uint→float converter used below */ - unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32); - if (input_size != target_size) { - src = convert_int(ctx, bld, src, input_size, target_size, false); + if (dst.regClass() == v2b) { + if (input_size <= 16) { + /* Expand integer to the size expected by the uint→float converter used below */ + unsigned target_size = (ctx->program->gfx_level >= GFX8 ? 16 : 32); + if (input_size != target_size) { + src = convert_int(ctx, bld, src, input_size, target_size, false); + } } - } - if (ctx->program->gfx_level >= GFX8 && input_size <= 16) { - bld.vop1(aco_opcode::v_cvt_f16_u16, Definition(dst), src); + if (ctx->program->gfx_level >= GFX8 && input_size <= 16) { + bld.vop1(aco_opcode::v_cvt_f16_u16, Definition(dst), src); + } else { + /* Large 32bit inputs need to return inf/FLOAT_MAX. + * + * This is also the fallback-path taken on GFX7 and earlier, which + * do not support direct f16⟷u16 conversions. + */ + src = bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), src); + bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + } + } else if (dst.regClass() == s1) { + if (input_size <= 16) { + src = convert_int(ctx, bld, src, input_size, 32, false); + } + src = bld.sop1(aco_opcode::s_cvt_f32_u32, bld.def(s1), src); + bld.sop1(aco_opcode::s_cvt_f16_f32, Definition(dst), src); } else { - /* Large 32bit inputs need to return inf/FLOAT_MAX. - * - * This is also the fallback-path taken on GFX7 and earlier, which - * do not support direct f16⟷u16 conversions. - */ - src = bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), src); - bld.vop1(aco_opcode::v_cvt_f16_f32, Definition(dst), src); + isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } break; } @@ -3049,12 +3091,15 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) assert(dst.size() == 1); Temp src = get_alu_src(ctx, instr->src[0]); const unsigned input_size = instr->src[0].src.ssa->bit_size; - if (input_size == 8) { + if (input_size == 8 && dst.regClass() == v1) { bld.vop1(aco_opcode::v_cvt_f32_ubyte0, Definition(dst), src); } else if (input_size <= 32) { - if (input_size == 16) + if (input_size <= 16) src = convert_int(ctx, bld, src, instr->src[0].src.ssa->bit_size, 32, false); - bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src); + if (dst.regClass() == v1) + bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src); + else + bld.sop1(aco_opcode::s_cvt_f32_u32, Definition(dst), src); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } @@ -3416,6 +3461,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) } case nir_op_unpack_half_2x16_split_x: { Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == s1) { + bld.sop1(aco_opcode::s_cvt_f32_f16, Definition(dst), src); + break; + } if (src.regClass() == v1) src = bld.pseudo(aco_opcode::p_split_vector, bld.def(v2b), bld.def(v2b), src); if (dst.regClass() == v1) { @@ -3427,6 +3476,10 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) } case nir_op_unpack_half_2x16_split_y: { Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == s1) { + bld.sop1(aco_opcode::s_cvt_hi_f32_f16, Definition(dst), src); + break; + } if (src.regClass() == s1) src = bld.pseudo(aco_opcode::p_extract, bld.def(s1), bld.def(s1, scc), src, Operand::c32(1u), Operand::c32(16u), Operand::zero()); diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index 8d8e187f053..65b85261781 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -347,16 +347,8 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_op_flog2: case nir_op_fsin_amd: case nir_op_fcos_amd: - case nir_op_f2f16: - case nir_op_f2f16_rtz: - case nir_op_f2f16_rtne: - case nir_op_f2f32: case nir_op_f2f64: - case nir_op_u2f16: - case nir_op_u2f32: case nir_op_u2f64: - case nir_op_i2f16: - case nir_op_i2f32: case nir_op_i2f64: case nir_op_pack_half_2x16_rtz_split: case nir_op_pack_half_2x16_split: @@ -364,8 +356,6 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_op_pack_snorm_2x16: case nir_op_pack_uint_2x16: case nir_op_pack_sint_2x16: - case nir_op_unpack_half_2x16_split_x: - case nir_op_unpack_half_2x16_split_y: case nir_op_fddx: case nir_op_fddy: case nir_op_fddx_fine: @@ -389,11 +379,21 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_op_sdot_2x16_iadd: case nir_op_udot_2x16_uadd_sat: case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break; + case nir_op_i2f16: + case nir_op_i2f32: + case nir_op_u2f16: + case nir_op_u2f32: + case nir_op_f2f16: + case nir_op_f2f16_rtz: + case nir_op_f2f16_rtne: + case nir_op_f2f32: case nir_op_ffract: case nir_op_ffloor: case nir_op_fceil: case nir_op_ftrunc: - case nir_op_fround_even: { + case nir_op_fround_even: + case nir_op_unpack_half_2x16_split_x: + case nir_op_unpack_half_2x16_split_y: { if (ctx->program->gfx_level < GFX11_5 || alu_instr->src[0].src.ssa->bit_size > 32) { type = RegType::vgpr;