diff --git a/src/amd/compiler/aco_optimizer_postRA.cpp b/src/amd/compiler/aco_optimizer_postRA.cpp index 67e8b71a09e..2325b37c503 100644 --- a/src/amd/compiler/aco_optimizer_postRA.cpp +++ b/src/amd/compiler/aco_optimizer_postRA.cpp @@ -284,183 +284,194 @@ try_apply_branch_vcc(pr_opt_ctx& ctx, aco_ptr& instr) } void -try_optimize_scc_nocompare(pr_opt_ctx& ctx, aco_ptr& instr) +try_optimize_to_scc_zero_cmp(pr_opt_ctx& ctx, aco_ptr& instr) { /* We are looking for the following pattern: * * s_bfe_u32 s0, s3, 0x40018 ; outputs SGPR and SCC if the SGPR != 0 * s_cmp_eq_i32 s0, 0 ; comparison between the SGPR and 0 - * s_cbranch_scc0 BB3 ; use the result of the comparison, eg. branch or cselect * * If possible, the above is optimized into: * * s_bfe_u32 s0, s3, 0x40018 ; original instruction - * s_cbranch_scc1 BB3 ; modified to use SCC directly rather than the SGPR with comparison + * s_cmp_eq_i32 scc, 0 ; comparison between the scc and 0 * + * This can then be further optimized by try_optimize_scc_nocompare. + * + * Alternatively, if scc is overwritten between the first instruction and the comparison, + * try to pull down the original instruction to replace the cmp entirely. */ - if (!instr->isSALU() && !instr->isBranch()) + if (!instr->isSOPC() || + (instr->opcode != aco_opcode::s_cmp_eq_u32 && instr->opcode != aco_opcode::s_cmp_eq_i32 && + instr->opcode != aco_opcode::s_cmp_lg_u32 && instr->opcode != aco_opcode::s_cmp_lg_i32 && + instr->opcode != aco_opcode::s_cmp_eq_u64 && instr->opcode != aco_opcode::s_cmp_lg_u64) || + (!instr->operands[0].constantEquals(0) && !instr->operands[1].constantEquals(0)) || + (!instr->operands[0].isTemp() && !instr->operands[1].isTemp())) return; - if (instr->isSOPC() && - (instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 || - instr->opcode == aco_opcode::s_cmp_lg_u32 || instr->opcode == aco_opcode::s_cmp_lg_i32 || - instr->opcode == aco_opcode::s_cmp_eq_u64 || instr->opcode == aco_opcode::s_cmp_lg_u64) && - (instr->operands[0].constantEquals(0) || instr->operands[1].constantEquals(0)) && - (instr->operands[0].isTemp() || instr->operands[1].isTemp())) { - /* Make sure the constant is always in operand 1 */ - if (instr->operands[0].isConstant()) - std::swap(instr->operands[0], instr->operands[1]); + /* Make sure the constant is always in operand 1 */ + if (instr->operands[0].isConstant()) + std::swap(instr->operands[0], instr->operands[1]); - /* Find the writer instruction of Operand 0. */ - Idx wr_idx = last_writer_idx(ctx, instr->operands[0]); - if (!wr_idx.found()) - return; + /* Find the writer instruction of Operand 0. */ + Idx wr_idx = last_writer_idx(ctx, instr->operands[0]); + if (!wr_idx.found()) + return; - Instruction* wr_instr = ctx.get(wr_idx); - if (!wr_instr->isSALU() || wr_instr->definitions.size() < 2 || - wr_instr->definitions[1].physReg() != scc) - return; + Instruction* wr_instr = ctx.get(wr_idx); + if (!wr_instr->isSALU() || wr_instr->definitions.size() < 2 || + wr_instr->definitions[1].physReg() != scc) + return; - /* Look for instructions which set SCC := (D != 0) */ - switch (wr_instr->opcode) { - case aco_opcode::s_bfe_i32: - case aco_opcode::s_bfe_i64: - case aco_opcode::s_bfe_u32: - case aco_opcode::s_bfe_u64: - case aco_opcode::s_and_b32: - case aco_opcode::s_and_b64: - case aco_opcode::s_andn2_b32: - case aco_opcode::s_andn2_b64: - case aco_opcode::s_or_b32: - case aco_opcode::s_or_b64: - case aco_opcode::s_orn2_b32: - case aco_opcode::s_orn2_b64: - case aco_opcode::s_xor_b32: - case aco_opcode::s_xor_b64: - case aco_opcode::s_not_b32: - case aco_opcode::s_not_b64: - case aco_opcode::s_nor_b32: - case aco_opcode::s_nor_b64: - case aco_opcode::s_xnor_b32: - case aco_opcode::s_xnor_b64: - case aco_opcode::s_nand_b32: - case aco_opcode::s_nand_b64: - case aco_opcode::s_lshl_b32: - case aco_opcode::s_lshl_b64: - case aco_opcode::s_lshr_b32: - case aco_opcode::s_lshr_b64: - case aco_opcode::s_ashr_i32: - case aco_opcode::s_ashr_i64: - case aco_opcode::s_abs_i32: - case aco_opcode::s_absdiff_i32: break; - default: return; - } - - /* Check whether both SCC and Operand 0 are written by the same instruction. */ - Idx sccwr_idx = last_writer_idx(ctx, scc, s1); - if (wr_idx != sccwr_idx) { - /* Check whether the current instruction is the only user of its first operand. */ - if (ctx.uses[wr_instr->definitions[1].tempId()] || - ctx.uses[wr_instr->definitions[0].tempId()] > 1) - return; - - /* Check whether the operands of the writer are overwritten. */ - for (const Operand& op : wr_instr->operands) { - if (is_overwritten_since(ctx, op, wr_idx)) - return; - } - - aco_opcode pulled_opcode = wr_instr->opcode; - if (instr->opcode == aco_opcode::s_cmp_eq_u32 || - instr->opcode == aco_opcode::s_cmp_eq_i32 || - instr->opcode == aco_opcode::s_cmp_eq_u64) { - /* When s_cmp_eq is used, it effectively inverts the SCC def. - * However, we can't simply invert the opcodes here because that - * would change the meaning of the program. - */ - return; - } - - Definition scc_def = instr->definitions[0]; - ctx.uses[wr_instr->definitions[0].tempId()]--; - - /* Copy the writer instruction, but use SCC from the current instr. - * This means that the original instruction will be eliminated. - */ - if (wr_instr->format == Format::SOP2) { - instr.reset(create_instruction(pulled_opcode, Format::SOP2, 2, 2)); - instr->operands[1] = wr_instr->operands[1]; - } else if (wr_instr->format == Format::SOP1) { - instr.reset(create_instruction(pulled_opcode, Format::SOP1, 1, 2)); - } - instr->definitions[0] = wr_instr->definitions[0]; - instr->definitions[1] = scc_def; - instr->operands[0] = wr_instr->operands[0]; - return; - } - - /* Use the SCC def from wr_instr */ - ctx.uses[instr->operands[0].tempId()]--; - instr->operands[0] = Operand(wr_instr->definitions[1].getTemp()); - instr->operands[0].setFixed(scc); - ctx.uses[instr->operands[0].tempId()]++; - - /* Set the opcode and operand to 32-bit */ - instr->operands[1] = Operand::zero(); - instr->opcode = - (instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 || - instr->opcode == aco_opcode::s_cmp_eq_u64) - ? aco_opcode::s_cmp_eq_u32 - : aco_opcode::s_cmp_lg_u32; - } else if ((instr->format == Format::PSEUDO_BRANCH && instr->operands.size() == 1 && - instr->operands[0].physReg() == scc) || - instr->opcode == aco_opcode::s_cselect_b32 || - instr->opcode == aco_opcode::s_cselect_b64) { - - /* For cselect, operand 2 is the SCC condition */ - unsigned scc_op_idx = 0; - if (instr->opcode == aco_opcode::s_cselect_b32 || - instr->opcode == aco_opcode::s_cselect_b64) { - scc_op_idx = 2; - } - - Idx wr_idx = last_writer_idx(ctx, instr->operands[scc_op_idx]); - if (!wr_idx.found()) - return; - - Instruction* wr_instr = ctx.get(wr_idx); - - /* Check if we found the pattern above. */ - if (wr_instr->opcode != aco_opcode::s_cmp_eq_u32 && - wr_instr->opcode != aco_opcode::s_cmp_lg_u32) - return; - if (wr_instr->operands[0].physReg() != scc) - return; - if (!wr_instr->operands[1].constantEquals(0)) - return; - - /* The optimization can be unsafe when there are other users. */ - if (ctx.uses[instr->operands[scc_op_idx].tempId()] > 1) - return; - - if (wr_instr->opcode == aco_opcode::s_cmp_eq_u32) { - /* Flip the meaning of the instruction to correctly use the SCC. */ - if (instr->format == Format::PSEUDO_BRANCH) - instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz - : aco_opcode::p_cbranch_z; - else if (instr->opcode == aco_opcode::s_cselect_b32 || - instr->opcode == aco_opcode::s_cselect_b64) - std::swap(instr->operands[0], instr->operands[1]); - else - unreachable( - "scc_nocompare optimization is only implemented for p_cbranch and s_cselect"); - } - - /* Use the SCC def from the original instruction, not the comparison */ - ctx.uses[instr->operands[scc_op_idx].tempId()]--; - instr->operands[scc_op_idx] = wr_instr->operands[0]; + /* Look for instructions which set SCC := (D != 0) */ + switch (wr_instr->opcode) { + case aco_opcode::s_bfe_i32: + case aco_opcode::s_bfe_i64: + case aco_opcode::s_bfe_u32: + case aco_opcode::s_bfe_u64: + case aco_opcode::s_and_b32: + case aco_opcode::s_and_b64: + case aco_opcode::s_andn2_b32: + case aco_opcode::s_andn2_b64: + case aco_opcode::s_or_b32: + case aco_opcode::s_or_b64: + case aco_opcode::s_orn2_b32: + case aco_opcode::s_orn2_b64: + case aco_opcode::s_xor_b32: + case aco_opcode::s_xor_b64: + case aco_opcode::s_not_b32: + case aco_opcode::s_not_b64: + case aco_opcode::s_nor_b32: + case aco_opcode::s_nor_b64: + case aco_opcode::s_xnor_b32: + case aco_opcode::s_xnor_b64: + case aco_opcode::s_nand_b32: + case aco_opcode::s_nand_b64: + case aco_opcode::s_lshl_b32: + case aco_opcode::s_lshl_b64: + case aco_opcode::s_lshr_b32: + case aco_opcode::s_lshr_b64: + case aco_opcode::s_ashr_i32: + case aco_opcode::s_ashr_i64: + case aco_opcode::s_abs_i32: + case aco_opcode::s_absdiff_i32: break; + default: return; } + + /* Check whether both SCC and Operand 0 are written by the same instruction. */ + Idx sccwr_idx = last_writer_idx(ctx, scc, s1); + if (wr_idx != sccwr_idx) { + /* Check whether the current instruction is the only user of its first operand. */ + if (ctx.uses[wr_instr->definitions[1].tempId()] || + ctx.uses[wr_instr->definitions[0].tempId()] > 1) + return; + + /* Check whether the operands of the writer are overwritten. */ + for (const Operand& op : wr_instr->operands) { + if (is_overwritten_since(ctx, op, wr_idx)) + return; + } + + aco_opcode pulled_opcode = wr_instr->opcode; + if (instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 || + instr->opcode == aco_opcode::s_cmp_eq_u64) { + /* When s_cmp_eq is used, it effectively inverts the SCC def. + * However, we can't simply invert the opcodes here because that + * would change the meaning of the program. + */ + return; + } + + Definition scc_def = instr->definitions[0]; + ctx.uses[wr_instr->definitions[0].tempId()]--; + + /* Copy the writer instruction, but use SCC from the current instr. + * This means that the original instruction will be eliminated. + */ + if (wr_instr->format == Format::SOP2) { + instr.reset(create_instruction(pulled_opcode, Format::SOP2, 2, 2)); + instr->operands[1] = wr_instr->operands[1]; + } else if (wr_instr->format == Format::SOP1) { + instr.reset(create_instruction(pulled_opcode, Format::SOP1, 1, 2)); + } + instr->definitions[0] = wr_instr->definitions[0]; + instr->definitions[1] = scc_def; + instr->operands[0] = wr_instr->operands[0]; + return; + } + + /* Use the SCC def from wr_instr */ + ctx.uses[instr->operands[0].tempId()]--; + instr->operands[0] = Operand(wr_instr->definitions[1].getTemp()); + instr->operands[0].setFixed(scc); + ctx.uses[instr->operands[0].tempId()]++; + + /* Set the opcode and operand to 32-bit */ + instr->operands[1] = Operand::zero(); + instr->opcode = + (instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 || + instr->opcode == aco_opcode::s_cmp_eq_u64) + ? aco_opcode::s_cmp_eq_u32 + : aco_opcode::s_cmp_lg_u32; +} + +void +try_optimize_scc_nocompare(pr_opt_ctx& ctx, aco_ptr& instr) +{ + /* If we have this pattern: + * s_cmp_eq_i32 scc, 0 ; comparison between scc and 0 + * s_cbranch_scc0 BB3 ; use the result of the comparison, eg. branch or cselect + * + * Turn it into: + * <> ; removed s_cmp + * s_cbranch_scc1 BB3 ; inverted branch + */ + + if ((instr->format != Format::PSEUDO_BRANCH || instr->operands.size() != 1 || + instr->operands[0].physReg() != scc) && + instr->opcode != aco_opcode::s_cselect_b32 && instr->opcode != aco_opcode::s_cselect_b64) + return; + + /* For cselect, operand 2 is the SCC condition */ + unsigned scc_op_idx = 0; + if (instr->opcode == aco_opcode::s_cselect_b32 || instr->opcode == aco_opcode::s_cselect_b64) { + scc_op_idx = 2; + } + + Idx wr_idx = last_writer_idx(ctx, instr->operands[scc_op_idx]); + if (!wr_idx.found()) + return; + + Instruction* wr_instr = ctx.get(wr_idx); + + /* Check if we found the pattern above. */ + if (wr_instr->opcode != aco_opcode::s_cmp_eq_u32 && wr_instr->opcode != aco_opcode::s_cmp_lg_u32) + return; + if (wr_instr->operands[0].physReg() != scc) + return; + if (!wr_instr->operands[1].constantEquals(0)) + return; + + /* The optimization can be unsafe when there are other users. */ + if (ctx.uses[instr->operands[scc_op_idx].tempId()] > 1) + return; + + if (wr_instr->opcode == aco_opcode::s_cmp_eq_u32) { + /* Flip the meaning of the instruction to correctly use the SCC. */ + if (instr->format == Format::PSEUDO_BRANCH) + instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz + : aco_opcode::p_cbranch_z; + else if (instr->opcode == aco_opcode::s_cselect_b32 || + instr->opcode == aco_opcode::s_cselect_b64) + std::swap(instr->operands[0], instr->operands[1]); + else + unreachable("scc_nocompare optimization is only implemented for p_cbranch and s_cselect"); + } + + /* Use the SCC def from the original instruction, not the comparison */ + ctx.uses[instr->operands[scc_op_idx].tempId()]--; + instr->operands[scc_op_idx] = wr_instr->operands[0]; } static bool @@ -1217,6 +1228,8 @@ process_instruction(pr_opt_ctx& ctx, aco_ptr& instr) try_apply_branch_vcc(ctx, instr); + try_optimize_to_scc_zero_cmp(ctx, instr); + try_optimize_scc_nocompare(ctx, instr); try_combine_dpp(ctx, instr);