diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index f34fd756e85..b12ce1c2649 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -1065,7 +1065,6 @@ emit_sopc_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Te
    assert(dst.regClass() == bld.lm);
    assert(src0.type() == RegType::sgpr);
    assert(src1.type() == RegType::sgpr);
-   assert(src0.regClass() == src1.regClass());
 
    /* Emit the SALU comparison instruction */
    Temp cmp = bld.sopc(op, bld.scc(bld.def(s1)), src0, src1);
@@ -3916,6 +3915,138 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
                           aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32);
       break;
    }
+   case nir_op_bitz:
+   case nir_op_bitnz: {
+      assert(instr->src[0].src.ssa->bit_size != 1);
+      bool test0 = instr->op == nir_op_bitz;
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = get_alu_src(ctx, instr->src[1]);
+      bool use_valu = src0.type() == RegType::vgpr || src1.type() == RegType::vgpr;
+      if (!use_valu) {
+         aco_opcode op = instr->src[0].src.ssa->bit_size == 64 ? aco_opcode::s_bitcmp1_b64
+                                                               : aco_opcode::s_bitcmp1_b32;
+         if (test0)
+            op = instr->src[0].src.ssa->bit_size == 64 ? aco_opcode::s_bitcmp0_b64
+                                                       : aco_opcode::s_bitcmp0_b32;
+         emit_sopc_instruction(ctx, instr, op, dst);
+         break;
+      }
+
+      /* We do not have a VALU version of s_bitcmp.
+       * But if the second source is constant, we can use
+       * v_cmp_class_f32's LUT to check the bit.
+       * The LUT only has 10 entries, so extract a higher byte if we have to.
+       * For sign bits, comparison with 0 is better because v_cmp_class
+       * can't be inverted.
+       */
+      if (nir_src_is_const(instr->src[1].src)) {
+         uint32_t bit = nir_alu_src_as_uint(instr->src[1]);
+         bit &= instr->src[0].src.ssa->bit_size - 1;
+         src0 = as_vgpr(ctx, src0);
+
+         if (src0.regClass() == v2) {
+            src0 = emit_extract_vector(ctx, src0, (bit & 32) != 0, v1);
+            bit &= 31;
+         }
+
+         if (bit == 31) {
+            bld.vopc(test0 ? aco_opcode::v_cmp_le_i32 : aco_opcode::v_cmp_gt_i32, Definition(dst),
+                     Operand::c32(0), src0);
+            break;
+         }
+
+         if (bit == 15 && ctx->program->gfx_level >= GFX8) {
+            bld.vopc(test0 ? aco_opcode::v_cmp_le_i16 : aco_opcode::v_cmp_gt_i16, Definition(dst),
+                     Operand::c32(0), src0);
+            break;
+         }
+
+         /* For the bit==7 case, this is only faster than v_cmp_class if test0 && can_sdwa. */
+         const bool can_sdwa = ctx->program->gfx_level >= GFX8 && ctx->program->gfx_level < GFX11;
+         if ((bit & 0x7) == 7 && ((test0 && can_sdwa) || bit != 7)) {
+            src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(bit / 8),
+                              Operand::c32(8), Operand::c32(1));
+            bld.vopc(test0 ? aco_opcode::v_cmp_le_i32 : aco_opcode::v_cmp_gt_i32, Definition(dst),
+                     Operand::c32(0), src0);
+            break;
+         }
+
+         /* Avoid snan for bit 24. */
+         if (bit == 24) {
+            src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(1),
+                              Operand::c32(16), Operand::c32(0));
+            bit &= 0xf;
+         }
+
+         /* avoid +inf if we can use sdwa+qnan */
+         if (bit > (can_sdwa ? 0x8 : 0x9)) {
+            src0 = bld.pseudo(aco_opcode::p_extract, bld.def(v1), src0, Operand::c32(bit / 8),
+                              Operand::c32(8), Operand::c32(0));
+            bit &= 0x7;
+         }
+
+         /* denorm and snan/qnan inputs are preserved using all float control modes. */
+         static const std::pair<uint32_t, bool> float_lut[10] = {
+            {0x7f800001, false}, /* snan */
+            {~0u, false},        /* qnan */
+            {0xff800000, false}, /* -inf */
+            {0xbf800000, false}, /* -normal (-1.0) */
+            {1, true},           /* -denormal */
+            {0, true},           /* -0.0 */
+            {0, false},          /* +0.0 */
+            {1, false},          /* +denormal */
+            {0x3f800000, false}, /* +normal (+1.0) */
+            {0x7f800000, false}, /* +inf */
+         };
+
+         Temp tmp = test0 ? bld.tmp(bld.lm) : dst;
+         if (ctx->program->gfx_level >= GFX8 && bit == 0) {
+            /* this can use s_movk. */
+            bld.vopc(aco_opcode::v_cmp_class_f16, Definition(tmp),
+                     bld.copy(bld.def(s1), Operand::c32(0x7c01)), src0);
+         } else {
+            VALU_instruction& res =
+               bld.vopc(aco_opcode::v_cmp_class_f32, Definition(tmp),
+                        bld.copy(bld.def(s1), Operand::c32(float_lut[bit].first)), src0)
+                  ->valu();
+            if (float_lut[bit].second) {
+               res.format = asVOP3(res.format);
+               res.neg[0] = true;
+            }
+         }
+
+         if (test0)
+            bld.sop1(Builder::s_not, Definition(dst), bld.def(s1, scc), tmp);
+
+         break;
+      }
+
+      Temp res;
+      aco_opcode op = test0 ? aco_opcode::v_cmp_eq_i32 : aco_opcode::v_cmp_lg_i32;
+      if (instr->src[0].src.ssa->bit_size == 16) {
+         op = test0 ? aco_opcode::v_cmp_eq_i16 : aco_opcode::v_cmp_lg_i16;
+         if (ctx->program->gfx_level < GFX10)
+            res = bld.vop2_e64(aco_opcode::v_lshlrev_b16, bld.def(v2b), src1, Operand::c32(1));
+         else
+            res = bld.vop3(aco_opcode::v_lshlrev_b16_e64, bld.def(v2b), src1, Operand::c32(1));
+
+         res = bld.vop2(aco_opcode::v_and_b32, bld.def(v2b), src0, res);
+      } else if (instr->src[0].src.ssa->bit_size == 32) {
+         res = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), src0, src1, Operand::c32(1));
+      } else if (instr->src[0].src.ssa->bit_size == 64) {
+         if (ctx->program->gfx_level < GFX8)
+            res = bld.vop3(aco_opcode::v_lshr_b64, bld.def(v2), src0, src1);
+         else
+            res = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), src1, src0);
+
+         res = emit_extract_vector(ctx, res, 0, v1);
+         res = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand::c32(0x1), res);
+      } else {
+         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
+      }
+      bld.vopc(op, Definition(dst), Operand::c32(0), res);
+      break;
+   }
    case nir_op_fddx:
    case nir_op_fddy:
    case nir_op_fddx_fine:
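
Note on the v_cmp_class trick above: v_cmp_class_f32 returns true when the class of its float operand is selected by its mask operand. The patch runs the instruction "backwards": the integer under test is passed as the mask, and a constant whose class index equals the tested bit is passed as the float, so the compare result is exactly that bit of the integer. The following minimal host-side sketch (not part of the patch) checks that each float_lut entry, with the neg modifier modeled as a sign-bit flip, classifies to its own index; classify_f32 and the CLASS_* names are illustrative stand-ins, not ACO or ISA symbols.

    /* Standalone verification sketch for the float_lut in the patch above. */
    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <utility>

    /* IEEE class indices as used by v_cmp_class_f32's mask operand. */
    enum {
       CLASS_SNAN = 0, CLASS_QNAN, CLASS_NEG_INF, CLASS_NEG_NORMAL, CLASS_NEG_DENORM,
       CLASS_NEG_ZERO, CLASS_POS_ZERO, CLASS_POS_DENORM, CLASS_POS_NORMAL, CLASS_POS_INF,
    };

    /* Classify a 32-bit float pattern into one of the ten classes. */
    static int classify_f32(uint32_t bits)
    {
       bool neg = bits >> 31;
       uint32_t exp = (bits >> 23) & 0xff;
       uint32_t mant = bits & 0x7fffff;
       if (exp == 0xff && mant)
          return (mant & 0x400000) ? CLASS_QNAN : CLASS_SNAN; /* quiet bit = mantissa MSB */
       if (exp == 0xff)
          return neg ? CLASS_NEG_INF : CLASS_POS_INF;
       if (exp == 0)
          return mant ? (neg ? CLASS_NEG_DENORM : CLASS_POS_DENORM)
                      : (neg ? CLASS_NEG_ZERO : CLASS_POS_ZERO);
       return neg ? CLASS_NEG_NORMAL : CLASS_POS_NORMAL;
    }

    int main()
    {
       /* Same table as the patch; .second set means the neg modifier is applied. */
       static const std::pair<uint32_t, bool> float_lut[10] = {
          {0x7f800001, false}, {~0u, false}, {0xff800000, false}, {0xbf800000, false},
          {1, true},           {0, true},    {0, false},          {1, false},
          {0x3f800000, false}, {0x7f800000, false},
       };
       for (int bit = 0; bit < 10; bit++) {
          uint32_t fp = float_lut[bit].first;
          if (float_lut[bit].second)
             fp ^= 0x80000000u; /* the neg modifier flips the sign bit */
          assert(classify_f32(fp) == bit);
       }
       puts("every float_lut entry classifies to its own LUT index");
       return 0;
    }

The {1, true} and {0, true} entries use the VOP3 neg modifier rather than encoding 0x80000001/0x80000000 directly, presumably so the copied constants stay inline-encodable (0 and 1) instead of becoming 32-bit literals.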