aco/opt_postRA: split try_optimize_scc_nocompare in two functions

These are two independent steps, so there is no real reason for them to be in
the same function.

No FOZ-DB changes.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33734>
This commit is contained in:
Georg Lehmann 2025-02-25 11:28:23 +01:00 committed by Marge Bot
parent 9d020826ca
commit 3386ea09d4

View file

@ -284,183 +284,194 @@ try_apply_branch_vcc(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
}
void
try_optimize_scc_nocompare(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
try_optimize_to_scc_zero_cmp(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
/* We are looking for the following pattern:
*
* s_bfe_u32 s0, s3, 0x40018 ; outputs SGPR and SCC if the SGPR != 0
* s_cmp_eq_i32 s0, 0 ; comparison between the SGPR and 0
* s_cbranch_scc0 BB3 ; use the result of the comparison, eg. branch or cselect
*
* If possible, the above is optimized into:
*
* s_bfe_u32 s0, s3, 0x40018 ; original instruction
* s_cbranch_scc1 BB3 ; modified to use SCC directly rather than the SGPR with comparison
* s_cmp_eq_i32 scc, 0 ; comparison between the scc and 0
*
* This can then be further optimized by try_optimize_scc_nocompare.
*
* Alternatively, if scc is overwritten between the first instruction and the comparison,
* try to pull down the original instruction to replace the cmp entirely.
*/
if (!instr->isSALU() && !instr->isBranch())
if (!instr->isSOPC() ||
(instr->opcode != aco_opcode::s_cmp_eq_u32 && instr->opcode != aco_opcode::s_cmp_eq_i32 &&
instr->opcode != aco_opcode::s_cmp_lg_u32 && instr->opcode != aco_opcode::s_cmp_lg_i32 &&
instr->opcode != aco_opcode::s_cmp_eq_u64 && instr->opcode != aco_opcode::s_cmp_lg_u64) ||
(!instr->operands[0].constantEquals(0) && !instr->operands[1].constantEquals(0)) ||
(!instr->operands[0].isTemp() && !instr->operands[1].isTemp()))
return;
if (instr->isSOPC() &&
(instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 ||
instr->opcode == aco_opcode::s_cmp_lg_u32 || instr->opcode == aco_opcode::s_cmp_lg_i32 ||
instr->opcode == aco_opcode::s_cmp_eq_u64 || instr->opcode == aco_opcode::s_cmp_lg_u64) &&
(instr->operands[0].constantEquals(0) || instr->operands[1].constantEquals(0)) &&
(instr->operands[0].isTemp() || instr->operands[1].isTemp())) {
/* Make sure the constant is always in operand 1 */
if (instr->operands[0].isConstant())
std::swap(instr->operands[0], instr->operands[1]);
/* Make sure the constant is always in operand 1 */
if (instr->operands[0].isConstant())
std::swap(instr->operands[0], instr->operands[1]);
/* Find the writer instruction of Operand 0. */
Idx wr_idx = last_writer_idx(ctx, instr->operands[0]);
if (!wr_idx.found())
return;
/* Find the writer instruction of Operand 0. */
Idx wr_idx = last_writer_idx(ctx, instr->operands[0]);
if (!wr_idx.found())
return;
Instruction* wr_instr = ctx.get(wr_idx);
if (!wr_instr->isSALU() || wr_instr->definitions.size() < 2 ||
wr_instr->definitions[1].physReg() != scc)
return;
Instruction* wr_instr = ctx.get(wr_idx);
if (!wr_instr->isSALU() || wr_instr->definitions.size() < 2 ||
wr_instr->definitions[1].physReg() != scc)
return;
/* Look for instructions which set SCC := (D != 0) */
switch (wr_instr->opcode) {
case aco_opcode::s_bfe_i32:
case aco_opcode::s_bfe_i64:
case aco_opcode::s_bfe_u32:
case aco_opcode::s_bfe_u64:
case aco_opcode::s_and_b32:
case aco_opcode::s_and_b64:
case aco_opcode::s_andn2_b32:
case aco_opcode::s_andn2_b64:
case aco_opcode::s_or_b32:
case aco_opcode::s_or_b64:
case aco_opcode::s_orn2_b32:
case aco_opcode::s_orn2_b64:
case aco_opcode::s_xor_b32:
case aco_opcode::s_xor_b64:
case aco_opcode::s_not_b32:
case aco_opcode::s_not_b64:
case aco_opcode::s_nor_b32:
case aco_opcode::s_nor_b64:
case aco_opcode::s_xnor_b32:
case aco_opcode::s_xnor_b64:
case aco_opcode::s_nand_b32:
case aco_opcode::s_nand_b64:
case aco_opcode::s_lshl_b32:
case aco_opcode::s_lshl_b64:
case aco_opcode::s_lshr_b32:
case aco_opcode::s_lshr_b64:
case aco_opcode::s_ashr_i32:
case aco_opcode::s_ashr_i64:
case aco_opcode::s_abs_i32:
case aco_opcode::s_absdiff_i32: break;
default: return;
}
/* Check whether both SCC and Operand 0 are written by the same instruction. */
Idx sccwr_idx = last_writer_idx(ctx, scc, s1);
if (wr_idx != sccwr_idx) {
/* Check whether the current instruction is the only user of its first operand. */
if (ctx.uses[wr_instr->definitions[1].tempId()] ||
ctx.uses[wr_instr->definitions[0].tempId()] > 1)
return;
/* Check whether the operands of the writer are overwritten. */
for (const Operand& op : wr_instr->operands) {
if (is_overwritten_since(ctx, op, wr_idx))
return;
}
aco_opcode pulled_opcode = wr_instr->opcode;
if (instr->opcode == aco_opcode::s_cmp_eq_u32 ||
instr->opcode == aco_opcode::s_cmp_eq_i32 ||
instr->opcode == aco_opcode::s_cmp_eq_u64) {
/* When s_cmp_eq is used, it effectively inverts the SCC def.
* However, we can't simply invert the opcodes here because that
* would change the meaning of the program.
*/
return;
}
Definition scc_def = instr->definitions[0];
ctx.uses[wr_instr->definitions[0].tempId()]--;
/* Copy the writer instruction, but use SCC from the current instr.
* This means that the original instruction will be eliminated.
*/
if (wr_instr->format == Format::SOP2) {
instr.reset(create_instruction(pulled_opcode, Format::SOP2, 2, 2));
instr->operands[1] = wr_instr->operands[1];
} else if (wr_instr->format == Format::SOP1) {
instr.reset(create_instruction(pulled_opcode, Format::SOP1, 1, 2));
}
instr->definitions[0] = wr_instr->definitions[0];
instr->definitions[1] = scc_def;
instr->operands[0] = wr_instr->operands[0];
return;
}
/* Use the SCC def from wr_instr */
ctx.uses[instr->operands[0].tempId()]--;
instr->operands[0] = Operand(wr_instr->definitions[1].getTemp());
instr->operands[0].setFixed(scc);
ctx.uses[instr->operands[0].tempId()]++;
/* Set the opcode and operand to 32-bit */
instr->operands[1] = Operand::zero();
instr->opcode =
(instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 ||
instr->opcode == aco_opcode::s_cmp_eq_u64)
? aco_opcode::s_cmp_eq_u32
: aco_opcode::s_cmp_lg_u32;
} else if ((instr->format == Format::PSEUDO_BRANCH && instr->operands.size() == 1 &&
instr->operands[0].physReg() == scc) ||
instr->opcode == aco_opcode::s_cselect_b32 ||
instr->opcode == aco_opcode::s_cselect_b64) {
/* For cselect, operand 2 is the SCC condition */
unsigned scc_op_idx = 0;
if (instr->opcode == aco_opcode::s_cselect_b32 ||
instr->opcode == aco_opcode::s_cselect_b64) {
scc_op_idx = 2;
}
Idx wr_idx = last_writer_idx(ctx, instr->operands[scc_op_idx]);
if (!wr_idx.found())
return;
Instruction* wr_instr = ctx.get(wr_idx);
/* Check if we found the pattern above. */
if (wr_instr->opcode != aco_opcode::s_cmp_eq_u32 &&
wr_instr->opcode != aco_opcode::s_cmp_lg_u32)
return;
if (wr_instr->operands[0].physReg() != scc)
return;
if (!wr_instr->operands[1].constantEquals(0))
return;
/* The optimization can be unsafe when there are other users. */
if (ctx.uses[instr->operands[scc_op_idx].tempId()] > 1)
return;
if (wr_instr->opcode == aco_opcode::s_cmp_eq_u32) {
/* Flip the meaning of the instruction to correctly use the SCC. */
if (instr->format == Format::PSEUDO_BRANCH)
instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
: aco_opcode::p_cbranch_z;
else if (instr->opcode == aco_opcode::s_cselect_b32 ||
instr->opcode == aco_opcode::s_cselect_b64)
std::swap(instr->operands[0], instr->operands[1]);
else
unreachable(
"scc_nocompare optimization is only implemented for p_cbranch and s_cselect");
}
/* Use the SCC def from the original instruction, not the comparison */
ctx.uses[instr->operands[scc_op_idx].tempId()]--;
instr->operands[scc_op_idx] = wr_instr->operands[0];
/* Look for instructions which set SCC := (D != 0) */
switch (wr_instr->opcode) {
case aco_opcode::s_bfe_i32:
case aco_opcode::s_bfe_i64:
case aco_opcode::s_bfe_u32:
case aco_opcode::s_bfe_u64:
case aco_opcode::s_and_b32:
case aco_opcode::s_and_b64:
case aco_opcode::s_andn2_b32:
case aco_opcode::s_andn2_b64:
case aco_opcode::s_or_b32:
case aco_opcode::s_or_b64:
case aco_opcode::s_orn2_b32:
case aco_opcode::s_orn2_b64:
case aco_opcode::s_xor_b32:
case aco_opcode::s_xor_b64:
case aco_opcode::s_not_b32:
case aco_opcode::s_not_b64:
case aco_opcode::s_nor_b32:
case aco_opcode::s_nor_b64:
case aco_opcode::s_xnor_b32:
case aco_opcode::s_xnor_b64:
case aco_opcode::s_nand_b32:
case aco_opcode::s_nand_b64:
case aco_opcode::s_lshl_b32:
case aco_opcode::s_lshl_b64:
case aco_opcode::s_lshr_b32:
case aco_opcode::s_lshr_b64:
case aco_opcode::s_ashr_i32:
case aco_opcode::s_ashr_i64:
case aco_opcode::s_abs_i32:
case aco_opcode::s_absdiff_i32: break;
default: return;
}
/* Check whether both SCC and Operand 0 are written by the same instruction. */
Idx sccwr_idx = last_writer_idx(ctx, scc, s1);
if (wr_idx != sccwr_idx) {
/* Check whether the current instruction is the only user of its first operand. */
if (ctx.uses[wr_instr->definitions[1].tempId()] ||
ctx.uses[wr_instr->definitions[0].tempId()] > 1)
return;
/* Check whether the operands of the writer are overwritten. */
for (const Operand& op : wr_instr->operands) {
if (is_overwritten_since(ctx, op, wr_idx))
return;
}
aco_opcode pulled_opcode = wr_instr->opcode;
if (instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 ||
instr->opcode == aco_opcode::s_cmp_eq_u64) {
/* When s_cmp_eq is used, it effectively inverts the SCC def.
* However, we can't simply invert the opcodes here because that
* would change the meaning of the program.
*/
return;
}
Definition scc_def = instr->definitions[0];
ctx.uses[wr_instr->definitions[0].tempId()]--;
/* Copy the writer instruction, but use SCC from the current instr.
* This means that the original instruction will be eliminated.
*/
if (wr_instr->format == Format::SOP2) {
instr.reset(create_instruction(pulled_opcode, Format::SOP2, 2, 2));
instr->operands[1] = wr_instr->operands[1];
} else if (wr_instr->format == Format::SOP1) {
instr.reset(create_instruction(pulled_opcode, Format::SOP1, 1, 2));
}
instr->definitions[0] = wr_instr->definitions[0];
instr->definitions[1] = scc_def;
instr->operands[0] = wr_instr->operands[0];
return;
}
/* Use the SCC def from wr_instr */
ctx.uses[instr->operands[0].tempId()]--;
instr->operands[0] = Operand(wr_instr->definitions[1].getTemp());
instr->operands[0].setFixed(scc);
ctx.uses[instr->operands[0].tempId()]++;
/* Set the opcode and operand to 32-bit */
instr->operands[1] = Operand::zero();
instr->opcode =
(instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 ||
instr->opcode == aco_opcode::s_cmp_eq_u64)
? aco_opcode::s_cmp_eq_u32
: aco_opcode::s_cmp_lg_u32;
}
void
try_optimize_scc_nocompare(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Handles a compare of SCC against zero followed by a consumer of that result:
    *
    *    s_cmp_eq_u32 / s_cmp_lg_u32  scc, 0
    *    <single-operand pseudo branch on scc, or s_cselect_b32/b64>
    *
    * The consumer is rewired to read the operand that was fed into the compare.
    * When the compare was an "eq", the consumer is inverted to keep its meaning.
    */
   const bool is_cselect = instr->opcode == aco_opcode::s_cselect_b32 ||
                           instr->opcode == aco_opcode::s_cselect_b64;
   const bool is_scc_branch = instr->format == Format::PSEUDO_BRANCH &&
                              instr->operands.size() == 1 &&
                              instr->operands[0].physReg() == scc;
   if (!is_scc_branch && !is_cselect)
      return;

   /* s_cselect carries its SCC condition in operand 2, the branch in operand 0. */
   const unsigned scc_op_idx = is_cselect ? 2 : 0;
   Operand& cond_op = instr->operands[scc_op_idx];

   Idx wr_idx = last_writer_idx(ctx, cond_op);
   if (!wr_idx.found())
      return;

   Instruction* wr_instr = ctx.get(wr_idx);

   /* The writer has to be exactly "s_cmp_{eq,lg}_u32 scc, 0". */
   const bool cmp_is_eq = wr_instr->opcode == aco_opcode::s_cmp_eq_u32;
   const bool cmp_is_lg = wr_instr->opcode == aco_opcode::s_cmp_lg_u32;
   if (!cmp_is_eq && !cmp_is_lg)
      return;
   if (wr_instr->operands[0].physReg() != scc || !wr_instr->operands[1].constantEquals(0))
      return;

   /* Bail out if the compare result has further users; rewriting could be unsafe then. */
   if (ctx.uses[cond_op.tempId()] > 1)
      return;

   if (cmp_is_eq) {
      /* "eq 0" negates the incoming SCC, so negate the consumer instead. */
      if (instr->format == Format::PSEUDO_BRANCH) {
         if (instr->opcode == aco_opcode::p_cbranch_z)
            instr->opcode = aco_opcode::p_cbranch_nz;
         else
            instr->opcode = aco_opcode::p_cbranch_z;
      } else if (is_cselect) {
         std::swap(instr->operands[0], instr->operands[1]);
      } else {
         unreachable("scc_nocompare optimization is only implemented for p_cbranch and s_cselect");
      }
   }

   /* Drop the use of the compare result and read the compare's SCC input directly. */
   ctx.uses[cond_op.tempId()]--;
   cond_op = wr_instr->operands[0];
}
static bool
@ -1217,6 +1228,8 @@ process_instruction(pr_opt_ctx& ctx, aco_ptr<Instruction>& instr)
try_apply_branch_vcc(ctx, instr);
try_optimize_to_scc_zero_cmp(ctx, instr);
try_optimize_scc_nocompare(ctx, instr);
try_combine_dpp(ctx, instr);