diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp
index c1c4cb64c22..4652af8ddb9 100644
--- a/src/amd/compiler/aco_ir.cpp
+++ b/src/amd/compiler/aco_ir.cpp
@@ -864,6 +864,13 @@ get_inverse(aco_opcode op)
    return get_cmp_info(op, &info) ? info.inverse : aco_opcode::num_opcodes;
 }
 
+aco_opcode
+get_swapped(aco_opcode op)
+{
+   CmpInfo info;
+   return get_cmp_info(op, &info) ? info.swapped : aco_opcode::num_opcodes;
+}
+
 aco_opcode
 get_f32_cmp(aco_opcode op)
 {
diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h
index d31e53ad300..f57ba4a9268 100644
--- a/src/amd/compiler/aco_ir.h
+++ b/src/amd/compiler/aco_ir.h
@@ -1883,6 +1883,7 @@ bool needs_exec_mask(const Instruction* instr);
 aco_opcode get_ordered(aco_opcode op);
 aco_opcode get_unordered(aco_opcode op);
 aco_opcode get_inverse(aco_opcode op);
+aco_opcode get_swapped(aco_opcode op);
 aco_opcode get_f32_cmp(aco_opcode op);
 aco_opcode get_vcmpx(aco_opcode op);
 unsigned get_cmp_bitsize(aco_opcode op);
diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index a75ca2a1b4f..a0ff60e366d 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -123,12 +123,13 @@ enum Label {
    label_f2f32 = 1ull << 37,
    label_f2f16 = 1ull << 38,
    label_split = 1ull << 39,
+   label_subgroup_invocation = 1ull << 40,
 };
 
 static constexpr uint64_t instr_usedef_labels =
    label_vec | label_mul | label_mad | label_add_sub | label_vop3p | label_bitwise |
    label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
-   label_dpp8 | label_f2f32;
+   label_dpp8 | label_f2f32 | label_subgroup_invocation;
 
 static constexpr uint64_t instr_mod_labels =
    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
@@ -494,6 +495,14 @@ struct ssa_info {
    }
 
    bool is_split() { return label & label_split; }
+
+   void set_subgroup_invocation(Instruction* label_instr)
+   {
+      add_label(label_subgroup_invocation);
+      instr = label_instr;
+   }
+
+   bool is_subgroup_invocation() { return label & label_subgroup_invocation; }
 };
 
 struct opt_ctx {
@@ -2125,6 +2134,26 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
       break;
    }
+   case aco_opcode::v_mbcnt_lo_u32_b32: {
+      if (instr->operands[0].constantEquals(-1) && instr->operands[1].constantEquals(0)) {
+         if (ctx.program->wave_size == 32)
+            ctx.info[instr->definitions[0].tempId()].set_subgroup_invocation(instr.get());
+         else
+            ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
+      }
+      break;
+   }
+   case aco_opcode::v_mbcnt_hi_u32_b32:
+   case aco_opcode::v_mbcnt_hi_u32_b32_e64: {
+      if (instr->operands[0].constantEquals(-1) && instr->operands[1].isTemp() &&
+          ctx.info[instr->operands[1].tempId()].is_usedef()) {
+         Instruction* usedef_instr = ctx.info[instr->operands[1].tempId()].instr;
+         if (usedef_instr->opcode == aco_opcode::v_mbcnt_lo_u32_b32 &&
+             usedef_instr->operands[0].constantEquals(-1) && usedef_instr->operands[1].constantEquals(0))
+            ctx.info[instr->definitions[0].tempId()].set_subgroup_invocation(instr.get());
+      }
+      break;
+   }
    case aco_opcode::v_cvt_f16_f32: {
       if (instr->operands[0].isTemp())
          ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
@@ -2370,6 +2399,87 @@ combine_comparison_ordering(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    return true;
 }
 
+/* Optimize v_cmp of constant with subgroup invocation to a constant mask.
+ * Ideally, we can trade v_cmp for a constant (or literal).
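+ * For example, in wave32, "subgroup_invocation < 16" becomes the constant mask 0xffff.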
+ * In a less ideal case, we trade v_cmp for a SALU instruction, which is still a win.
+ */
+bool
+optimize_cmp_subgroup_invocation(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   /* This optimization only applies to VOPC with 2 operands. */
+   if (instr->operands.size() != 2)
+      return false;
+
+   /* Find the constant operand or return early if there isn't one. */
+   const int const_op_idx = instr->operands[0].isConstant() ? 0 : instr->operands[1].isConstant() ? 1 : -1;
+   if (const_op_idx == -1)
+      return false;
+
+   /* Find the operand that has the subgroup invocation. */
+   const int mbcnt_op_idx = 1 - const_op_idx;
+   const Operand mbcnt_op = instr->operands[mbcnt_op_idx];
+   if (!mbcnt_op.isTemp() || !ctx.info[mbcnt_op.tempId()].is_subgroup_invocation())
+      return false;
+
+   /* Adjust opcode so we don't have to care about const_op_idx below. */
+   const aco_opcode op = const_op_idx == 0 ? get_swapped(instr->opcode) : instr->opcode;
+   const unsigned wave_size = ctx.program->wave_size;
+   const unsigned val = instr->operands[const_op_idx].constantValue();
+
+   /* Find suitable constant bitmask corresponding to the value. */
+   unsigned first_bit = 0, num_bits = 0;
+   switch (op) {
+   case aco_opcode::v_cmp_eq_u32:
+   case aco_opcode::v_cmp_eq_i32:
+      first_bit = val;
+      num_bits = val >= wave_size ? 0 : 1;
+      break;
+   case aco_opcode::v_cmp_le_u32:
+   case aco_opcode::v_cmp_le_i32:
+      first_bit = 0;
+      num_bits = val >= wave_size ? wave_size : (val + 1);
+      break;
+   case aco_opcode::v_cmp_lt_u32:
+   case aco_opcode::v_cmp_lt_i32:
+      first_bit = 0;
+      num_bits = val >= wave_size ? wave_size : val;
+      break;
+   case aco_opcode::v_cmp_ge_u32:
+   case aco_opcode::v_cmp_ge_i32:
+      first_bit = val;
+      num_bits = val >= wave_size ? 0 : (wave_size - val);
+      break;
+   case aco_opcode::v_cmp_gt_u32:
+   case aco_opcode::v_cmp_gt_i32:
+      first_bit = val + 1;
+      num_bits = val >= wave_size ? 0 : (wave_size - val - 1);
+      break;
+   default:
+      return false;
+   }
+
+   Instruction* cpy = NULL;
+   const uint64_t mask = BITFIELD64_RANGE(first_bit, num_bits);
+   if (wave_size == 64 && mask > 0x7fffffff && mask != -1ull) {
+      /* Mask can't be represented as a 64-bit constant or literal, use s_bfm_b64. */
+      cpy = create_instruction<SOP2_instruction>(aco_opcode::s_bfm_b64, Format::SOP2, 2, 1);
+      cpy->operands[0] = Operand::c32(num_bits);
+      cpy->operands[1] = Operand::c32(first_bit);
+   } else {
+      /* Copy mask as a literal constant. */
+      cpy = create_instruction<Pseudo_instruction>(aco_opcode::p_parallelcopy, Format::PSEUDO, 1, 1);
+      cpy->operands[0] = wave_size == 32 ? Operand::c32((uint32_t)mask) : Operand::c64(mask);
+   }
+
+   cpy->definitions[0] = instr->definitions[0];
+   ctx.info[instr->definitions[0].tempId()].label = 0;
+   decrease_uses(ctx, ctx.info[mbcnt_op.tempId()].instr);
+   instr.reset(cpy);
+
+   return true;
+}
+
 bool
 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
 {
@@ -4045,6 +4155,11 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       apply_ds_extract(ctx, instr);
    }
 
+   if (instr->isVOPC()) {
+      if (optimize_cmp_subgroup_invocation(ctx, instr))
+         return;
+   }
+
    /* TODO: There are still some peephole optimizations that could be done:
    * - abs(a - b) -> s_absdiff_i32
    * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
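
For context (not part of the patch): the labeling recognizes `v_mbcnt_lo_u32_b32(-1, 0)`, chained through `v_mbcnt_hi_u32_b32` in wave64, because mbcnt adds to its addend operand the number of set bits of the mask operand at positions below the current lane; with an all-ones mask, that is exactly the lane index. A standalone model of those semantics (hypothetical helper names, GCC/Clang `__builtin_popcount`; not Mesa API):

```cpp
#include <cassert>
#include <cstdint>

/* v_mbcnt_lo_u32_b32: per lane, base plus the number of set bits of
 * mask[31:0] belonging to lanes below this one. */
static unsigned mbcnt_lo(uint32_t mask, unsigned base, unsigned lane)
{
   uint32_t below = lane >= 32 ? mask : mask & ((1u << lane) - 1);
   return base + __builtin_popcount(below);
}

/* v_mbcnt_hi_u32_b32: same, but the mask operand covers lanes 32..63. */
static unsigned mbcnt_hi(uint32_t mask, unsigned base, unsigned lane)
{
   uint32_t below = lane < 32 ? 0 : mask & ((1u << (lane - 32)) - 1);
   return base + __builtin_popcount(below);
}

int main()
{
   /* With an all-ones mask and a zero addend, the chain is the lane index. */
   for (unsigned lane = 0; lane < 64; lane++)
      assert(mbcnt_hi(~0u, mbcnt_lo(~0u, 0, lane), lane) == lane);
   return 0;
}
```

In wave32 the `_lo` half already covers every lane, which is why the patch labels it as the subgroup invocation directly there, and in wave64 only labels a `_hi` whose second operand comes from a matching `_lo(-1, 0)`.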
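The `first_bit`/`num_bits` switch can be sanity-checked outside the compiler. A minimal standalone sketch of the same derivation (ad-hoc names, not Mesa API; `bitfield_range()` stands in for Mesa's `BITFIELD64_RANGE` macro):

```cpp
#include <cassert>
#include <cstdint>

enum class cmp { eq, lt, le, gt, ge };

/* Equivalent of BITFIELD64_RANGE(first_bit, num_bits), avoiding shifts by 64. */
static uint64_t bitfield_range(unsigned first_bit, unsigned num_bits)
{
   if (num_bits == 0 || first_bit >= 64)
      return 0;
   uint64_t bits = num_bits >= 64 ? ~0ull : (1ull << num_bits) - 1;
   return bits << first_bit;
}

/* Lane mask for "subgroup_invocation OP val", mirroring the patch's switch. */
static uint64_t invocation_cmp_mask(cmp op, unsigned val, unsigned wave_size)
{
   unsigned first_bit = 0, num_bits = 0;
   switch (op) {
   case cmp::eq: first_bit = val;     num_bits = val >= wave_size ? 0 : 1; break;
   case cmp::le: first_bit = 0;       num_bits = val >= wave_size ? wave_size : val + 1; break;
   case cmp::lt: first_bit = 0;       num_bits = val >= wave_size ? wave_size : val; break;
   case cmp::ge: first_bit = val;     num_bits = val >= wave_size ? 0 : wave_size - val; break;
   case cmp::gt: first_bit = val + 1; num_bits = val >= wave_size ? 0 : wave_size - val - 1; break;
   }
   return bitfield_range(first_bit, num_bits);
}

int main()
{
   assert(invocation_cmp_mask(cmp::lt, 16, 32) == 0xffffu);               /* first 16 lanes */
   assert(invocation_cmp_mask(cmp::eq, 5, 32) == 0x20u);                  /* lane 5 only */
   assert(invocation_cmp_mask(cmp::ge, 32, 64) == 0xffffffff00000000ull); /* upper half */
   assert(invocation_cmp_mask(cmp::eq, 64, 64) == 0);                     /* out of range */
   return 0;
}
```

Note that out-of-range constants fold to an empty or full mask rather than bailing out, so even degenerate comparisons collapse to a plain constant.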
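On the `s_bfm_b64` path: as the `mask > 0x7fffffff && mask != -1ull` check implies, a 64-bit SALU operand can be fed from an inline constant such as -1 or a 32-bit literal, but not an arbitrary 64-bit immediate, so masks reaching into the upper lanes of a wave64 are built with `s_bfm_b64` instead. Its documented behavior is `D = ((1 << S0[5:0]) - 1) << S1[5:0]`, sketched here:

```cpp
#include <cassert>
#include <cstdint>

/* s_bfm_b64: bitfield mask of S0[5:0] ones starting at bit S1[5:0]. */
static uint64_t s_bfm_b64(uint32_t s0, uint32_t s1)
{
   return ((1ull << (s0 & 0x3f)) - 1) << (s1 & 0x3f);
}

int main()
{
   /* "invocation >= 32" in wave64: first_bit = 32, num_bits = 32. */
   assert(s_bfm_b64(32, 32) == 0xffffffff00000000ull);
   /* "invocation > 47" in wave64: first_bit = 48, num_bits = 16. */
   assert(s_bfm_b64(16, 48) == 0xffff000000000000ull);
   return 0;
}
```

This is the "less ideal case" from the comment above: still one SALU instruction in place of the VALU compare, and the `v_mbcnt` chain can be dead-code-eliminated afterwards if this was its only use (hence the `decrease_uses` call).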