diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 79ae746ddf8..aa18c3e5053 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -3222,6 +3222,29 @@ sop2_can_use_sopk(ra_ctx& ctx, Instruction* instr) return true; } +bool +sop2_can_use_bitset(ra_ctx& ctx, Instruction* instr) +{ + if (instr->opcode != aco_opcode::s_and_b32 && instr->opcode != aco_opcode::s_and_b64 && + instr->opcode != aco_opcode::s_or_b32 && instr->opcode != aco_opcode::s_or_b64) + return false; + + uint32_t const_idx = instr->operands[0].isConstant() ? 0 : 1; + if (!instr->definitions[1].isKill() || !instr->operands[const_idx].isConstant() || + !instr->operands[!const_idx].isTemp() || !instr->operands[!const_idx].isKillBeforeDef()) + return false; + + uint64_t val = instr->operands[const_idx].constantValue64(); + + switch (instr->opcode) { + case aco_opcode::s_and_b32: return util_bitcount(val) == 31; + case aco_opcode::s_and_b64: return util_bitcount64(val) == 63; + case aco_opcode::s_or_b32: return util_bitcount(val) == 1; + case aco_opcode::s_or_b64: return util_bitcount64(val) == 1; + default: return false; + } +} + void create_phi_vector_affinities(ra_ctx& ctx, aco_ptr& instr, std::map>& vector_phis) @@ -3370,6 +3393,8 @@ get_affinities(ra_ctx& ctx) op = instr->operands[2]; } else if (i == 0 && sop2_can_use_sopk(ctx, instr.get())) { op = instr->operands[instr->operands[0].isLiteral()]; + } else if (i == 0 && sop2_can_use_bitset(ctx, instr.get())) { + op = instr->operands[instr->operands[0].isConstant()]; } else { continue; } @@ -3583,34 +3608,56 @@ optimize_encoding_vop2(ra_ctx& ctx, RegisterFile& register_file, aco_ptr& instr) +optimize_encoding_sopk_sop1(ra_ctx& ctx, RegisterFile& register_file, aco_ptr& instr) { - /* try to optimize sop2 with literal source to sopk */ - if (!sop2_can_use_sopk(ctx, instr.get())) + /* try to optimize sop2 with literal source to sopk, or s_and/s_or to s_bitset */ + bool sopk = sop2_can_use_sopk(ctx, instr.get()); + bool bitset = sop2_can_use_bitset(ctx, instr.get()); + if (!sopk && !bitset) return; - unsigned literal_idx = instr->operands[1].isLiteral(); + unsigned const_idx = instr->operands[1].isConstant(); - PhysReg op_reg = instr->operands[!literal_idx].physReg(); + PhysReg op_reg = instr->operands[!const_idx].physReg(); if (!is_sgpr_writable_without_side_effects(ctx.program->gfx_level, op_reg)) return; - if (affinity_blocks_tied_def0(ctx, register_file, instr.get(), !literal_idx)) + if (affinity_blocks_tied_def0(ctx, register_file, instr.get(), !const_idx)) return; - instr->format = Format::SOPK; - instr->salu().imm = instr->operands[literal_idx].constantValue() & 0xffff; - if (literal_idx == 0) - std::swap(instr->operands[0], instr->operands[1]); - if (instr->operands.size() > 2) - std::swap(instr->operands[1], instr->operands[2]); - instr->operands.pop_back(); + if (sopk) { + instr->format = Format::SOPK; + instr->salu().imm = instr->operands[const_idx].constantValue() & 0xffff; + if (const_idx == 0) + std::swap(instr->operands[0], instr->operands[1]); + if (instr->operands.size() > 2) + std::swap(instr->operands[1], instr->operands[2]); + instr->operands.pop_back(); - switch (instr->opcode) { - case aco_opcode::s_add_u32: - case aco_opcode::s_add_i32: instr->opcode = aco_opcode::s_addk_i32; break; - case aco_opcode::s_mul_i32: instr->opcode = aco_opcode::s_mulk_i32; break; - case aco_opcode::s_cselect_b32: instr->opcode = aco_opcode::s_cmovk_i32; break; - default: UNREACHABLE("illegal instruction"); + switch (instr->opcode) { + case aco_opcode::s_add_u32: + case aco_opcode::s_add_i32: instr->opcode = aco_opcode::s_addk_i32; break; + case aco_opcode::s_mul_i32: instr->opcode = aco_opcode::s_mulk_i32; break; + case aco_opcode::s_cselect_b32: instr->opcode = aco_opcode::s_cmovk_i32; break; + default: UNREACHABLE("illegal instruction"); + } + } else { + instr->format = Format::SOP1; + if (const_idx == 1) + std::swap(instr->operands[0], instr->operands[1]); + instr->definitions.pop_back(); + + switch (instr->opcode) { + case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_bitset0_b32; break; + case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_bitset0_b64; break; + case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_bitset1_b32; break; + case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_bitset1_b64; break; + default: UNREACHABLE("illegal instruction"); + } + + uint64_t val = instr->operands[0].constantValue64(); + if (instr->opcode == aco_opcode::s_bitset0_b32 || instr->opcode == aco_opcode::s_bitset0_b64) + val = ~val; + instr->operands[0] = Operand::c32(ffsll(val) - 1); } } @@ -3620,7 +3667,7 @@ optimize_encoding(ra_ctx& ctx, RegisterFile& register_file, aco_ptr if (instr->isVALU()) optimize_encoding_vop2(ctx, register_file, instr); if (instr->isSALU()) - optimize_encoding_sopk(ctx, register_file, instr); + optimize_encoding_sopk_sop1(ctx, register_file, instr); } void