diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp index 11ff7dc735a..96d342f1a2d 100644 --- a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp +++ b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp @@ -3875,115 +3875,113 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) * For sign bits comparision with 0 is better because v_cmp_class * can't be inverted. */ + unsigned bit; if (nir_src_is_const(instr->src[1].src)) { - uint32_t bit = nir_alu_src_as_uint(instr->src[1]); + bit = nir_alu_src_as_uint(instr->src[1]); bit &= instr->src[0].src.ssa->bit_size - 1; - src0 = as_vgpr(ctx, src0); - - if (src0.regClass() == v2) { - src0 = emit_extract_vector(ctx, src0, (bit & 32) != 0, v1); - bit &= 31; - } - - if (bit == 31) { - bld.vopc(test0 ? aco_opcode::v_cmp_le_i32 : aco_opcode::v_cmp_gt_i32, Definition(dst), - Operand::c32(0), src0); + } else { + if (instr->src[0].src.ssa->bit_size == 32) { + Temp res = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), src0, src1, Operand::c32(1)); + aco_opcode op = test0 ? aco_opcode::v_cmp_eq_i32 : aco_opcode::v_cmp_lg_i32; + bld.vopc(op, Definition(dst), Operand::c32(0), res); break; + } else if (instr->src[0].src.ssa->bit_size == 16) { + if (ctx->program->gfx_level < GFX10 && src0.type() != RegType::vgpr) + src0 = bld.vop2_e64(aco_opcode::v_lshrrev_b16, bld.def(v2b), src1, src0); + else if (ctx->program->gfx_level < GFX10) + src0 = bld.vop2(aco_opcode::v_lshrrev_b16, bld.def(v2b), src1, src0); + else + src0 = bld.vop3(aco_opcode::v_lshrrev_b16_e64, bld.def(v2b), src1, src0); + } else if (instr->src[0].src.ssa->bit_size == 64) { + if (ctx->program->gfx_level < GFX8) + src0 = bld.vop3(aco_opcode::v_lshr_b64, bld.def(v2), src0, src1); + else + src0 = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), src1, src0); + } else { + isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } - if (bit == 15 && ctx->program->gfx_level >= GFX8) { - bld.vopc(test0 ? aco_opcode::v_cmp_le_i16 : aco_opcode::v_cmp_gt_i16, Definition(dst), - Operand::c32(0), src0); - break; - } + bit = 0; + } - /* Set max_bit lower to avoid +inf if we can use sdwa+qnan instead. */ - const bool can_sdwa = ctx->program->gfx_level >= GFX8 && ctx->program->gfx_level < GFX11; - const unsigned max_bit = can_sdwa ? 0x8 : 0x9; - const bool use_opsel = bit > 0xf && (bit & 0xf) <= max_bit; - if (use_opsel) { - src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(1), - Operand::c32(16), Operand::c32(0)); - bit &= 0xf; - } + src0 = as_vgpr(ctx, src0); - /* If we can use sdwa the extract is free, while test0's s_not is not. */ - if (bit == 7 && test0 && can_sdwa) { - src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(bit / 8), - Operand::c32(8), Operand::c32(1)); - bld.vopc(test0 ? aco_opcode::v_cmp_le_i32 : aco_opcode::v_cmp_gt_i32, Definition(dst), - Operand::c32(0), src0); - break; - } - - if (bit > max_bit) { - src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(bit / 8), - Operand::c32(8), Operand::c32(0)); - bit &= 0x7; - } - - /* denorm and snan/qnan inputs are preserved using all float control modes. */ - static const struct { - uint32_t fp32; - uint32_t fp16; - bool negate; - } float_lut[10] = { - {0x7f800001, 0x7c01, false}, /* snan */ - {~0u, ~0u, false}, /* qnan */ - {0xff800000, 0xfc00, false}, /* -inf */ - {0xbf800000, 0xbc00, false}, /* -normal (-1.0) */ - {1, 1, true}, /* -denormal */ - {0, 0, true}, /* -0.0 */ - {0, 0, false}, /* +0.0 */ - {1, 1, false}, /* +denormal */ - {0x3f800000, 0x3c00, false}, /* +normal (+1.0) */ - {0x7f800000, 0x7c00, false}, /* +inf */ - }; - - Temp tmp = test0 ? bld.tmp(bld.lm) : dst; - /* fp16 can use s_movk for bit 0. It also supports opsel on gfx11. */ - const bool use_fp16 = (ctx->program->gfx_level >= GFX8 && bit == 0) || - (ctx->program->gfx_level >= GFX11 && use_opsel); - const aco_opcode op = use_fp16 ? aco_opcode::v_cmp_class_f16 : aco_opcode::v_cmp_class_f32; - const uint32_t c = use_fp16 ? float_lut[bit].fp16 : float_lut[bit].fp32; - - VALU_instruction& res = - bld.vopc(op, Definition(tmp), bld.copy(bld.def(s1), Operand::c32(c)), src0)->valu(); - if (float_lut[bit].negate) { - res.format = asVOP3(res.format); - res.neg[0] = true; - } - - if (test0) - bld.sop1(Builder::s_not, Definition(dst), bld.def(s1, scc), tmp); + if (src0.regClass() == v2) { + src0 = emit_extract_vector(ctx, src0, (bit & 32) != 0, v1); + bit &= 31; + } + if (bit == 31) { + bld.vopc(test0 ? aco_opcode::v_cmp_le_i32 : aco_opcode::v_cmp_gt_i32, Definition(dst), + Operand::c32(0), src0); break; } - Temp res; - aco_opcode op = test0 ? aco_opcode::v_cmp_eq_i32 : aco_opcode::v_cmp_lg_i32; - if (instr->src[0].src.ssa->bit_size == 16) { - op = test0 ? aco_opcode::v_cmp_eq_i16 : aco_opcode::v_cmp_lg_i16; - if (ctx->program->gfx_level < GFX10) - res = bld.vop2_e64(aco_opcode::v_lshlrev_b16, bld.def(v2b), src1, Operand::c32(1)); - else - res = bld.vop3(aco_opcode::v_lshlrev_b16_e64, bld.def(v2b), src1, Operand::c32(1)); - - res = bld.vop2(aco_opcode::v_and_b32, bld.def(v2b), src0, res); - } else if (instr->src[0].src.ssa->bit_size == 32) { - res = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), src0, src1, Operand::c32(1)); - } else if (instr->src[0].src.ssa->bit_size == 64) { - if (ctx->program->gfx_level < GFX8) - res = bld.vop3(aco_opcode::v_lshr_b64, bld.def(v2), src0, src1); - else - res = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), src1, src0); - - res = emit_extract_vector(ctx, res, 0, v1); - res = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand::c32(0x1), res); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); + if (bit == 15 && ctx->program->gfx_level >= GFX8) { + bld.vopc(test0 ? aco_opcode::v_cmp_le_i16 : aco_opcode::v_cmp_gt_i16, Definition(dst), + Operand::c32(0), src0); + break; } - bld.vopc(op, Definition(dst), Operand::c32(0), res); + + /* Set max_bit lower to avoid +inf if we can use sdwa+qnan instead. */ + const bool can_sdwa = ctx->program->gfx_level >= GFX8 && ctx->program->gfx_level < GFX11; + const unsigned max_bit = can_sdwa ? 0x8 : 0x9; + const bool use_opsel = bit > 0xf && (bit & 0xf) <= max_bit; + if (use_opsel) { + src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(1), + Operand::c32(16), Operand::c32(0)); + bit &= 0xf; + } + + /* If we can use sdwa the extract is free, while test0's s_not is not. */ + if (bit == 7 && test0 && can_sdwa) { + src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(bit / 8), + Operand::c32(8), Operand::c32(1)); + bld.vopc(test0 ? aco_opcode::v_cmp_le_i32 : aco_opcode::v_cmp_gt_i32, Definition(dst), + Operand::c32(0), src0); + break; + } + + if (bit > max_bit) { + src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(bit / 8), + Operand::c32(8), Operand::c32(0)); + bit &= 0x7; + } + + /* denorm and snan/qnan inputs are preserved using all float control modes. */ + static const struct { + uint32_t fp32; + uint32_t fp16; + bool negate; + } float_lut[10] = { + {0x7f800001, 0x7c01, false}, /* snan */ + {~0u, ~0u, false}, /* qnan */ + {0xff800000, 0xfc00, false}, /* -inf */ + {0xbf800000, 0xbc00, false}, /* -normal (-1.0) */ + {1, 1, true}, /* -denormal */ + {0, 0, true}, /* -0.0 */ + {0, 0, false}, /* +0.0 */ + {1, 1, false}, /* +denormal */ + {0x3f800000, 0x3c00, false}, /* +normal (+1.0) */ + {0x7f800000, 0x7c00, false}, /* +inf */ + }; + + Temp tmp = test0 ? bld.tmp(bld.lm) : dst; + /* fp16 can use s_movk for bit 0. It also supports opsel on gfx11. */ + const bool use_fp16 = (ctx->program->gfx_level >= GFX8 && bit == 0) || + (ctx->program->gfx_level >= GFX11 && use_opsel); + const aco_opcode op = use_fp16 ? aco_opcode::v_cmp_class_f16 : aco_opcode::v_cmp_class_f32; + const uint32_t c = use_fp16 ? float_lut[bit].fp16 : float_lut[bit].fp32; + + VALU_instruction& res = + bld.vopc(op, Definition(tmp), bld.copy(bld.def(s1), Operand::c32(c)), src0)->valu(); + if (float_lut[bit].negate) { + res.format = asVOP3(res.format); + res.neg[0] = true; + } + + if (test0) + bld.sop1(Builder::s_not, Definition(dst), bld.def(s1, scc), tmp); break; } default: isel_err(&instr->instr, "Unknown NIR ALU instr");