diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index e78472b8f19..6832dbf78f6 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -2810,6 +2810,68 @@ use_absdiff:
    return true;
 }
 
+/* s_cmp_{lg,eq}(s_and(a, s_lshl(1, b)), 0) -> s_bitcmp[10](a, b) */
+bool
+combine_s_bitcmp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   bool lg = false;
+   bool b64 = false;
+   switch (instr->opcode) {
+   case aco_opcode::s_cmp_lg_i32:
+   case aco_opcode::s_cmp_lg_u32: lg = true; break;
+   case aco_opcode::s_cmp_eq_i32:
+   case aco_opcode::s_cmp_eq_u32: break;
+   case aco_opcode::s_cmp_lg_u64: lg = true; FALLTHROUGH;
+   case aco_opcode::s_cmp_eq_u64: b64 = true; break;
+   default: return false;
+   }
+
+   aco_opcode s_and = b64 ? aco_opcode::s_and_b64 : aco_opcode::s_and_b32;
+   aco_opcode s_lshl = b64 ? aco_opcode::s_lshl_b64 : aco_opcode::s_lshl_b32;
+
+   for (unsigned cmp_idx = 0; cmp_idx < 2; cmp_idx++) {
+      Instruction* and_instr = follow_operand(ctx, instr->operands[cmp_idx], false);
+      if (!and_instr || and_instr->opcode != s_and)
+         continue;
+
+      for (unsigned and_idx = 0; and_idx < 2; and_idx++) {
+         Instruction* lshl_instr = follow_operand(ctx, and_instr->operands[and_idx], true);
+         if (!lshl_instr || lshl_instr->opcode != s_lshl ||
+             !lshl_instr->operands[0].constantEquals(1) ||
+             (lshl_instr->operands[1].isLiteral() && and_instr->operands[!and_idx].isLiteral()))
+            continue;
+
+         bool test1 = false;
+         if (instr->operands[!cmp_idx].constantEquals(0)) {
+            test1 = lg;
+         } else if (instr->operands[!cmp_idx].isTemp() &&
+                    instr->operands[!cmp_idx].tempId() == lshl_instr->definitions[0].tempId()) {
+            test1 = !lg;
+            ctx.uses[lshl_instr->definitions[0].tempId()]--;
+         } else {
+            continue;
+         }
+
+         if (test1 && b64)
+            instr->opcode = aco_opcode::s_bitcmp1_b64;
+         else if (!test1 && b64)
+            instr->opcode = aco_opcode::s_bitcmp0_b64;
+         else if (test1 && !b64)
+            instr->opcode = aco_opcode::s_bitcmp1_b32;
+         else
+            instr->opcode = aco_opcode::s_bitcmp0_b32;
+
+         instr->operands[0] = copy_operand(ctx, and_instr->operands[!and_idx]);
+         instr->operands[1] = copy_operand(ctx, lshl_instr->operands[1]);
+         decrease_uses(ctx, and_instr);
+         decrease_op_uses_if_dead(ctx, lshl_instr);
+         return true;
+      }
+   }
+
+   return false;
+}
+
 bool
 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
 {
@@ -4226,6 +4288,13 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       }
    } else if (instr->opcode == aco_opcode::s_abs_i32) {
       combine_sabsdiff(ctx, instr);
+   } else if (instr->opcode == aco_opcode::s_cmp_lg_i32 ||
+              instr->opcode == aco_opcode::s_cmp_lg_u32 ||
+              instr->opcode == aco_opcode::s_cmp_lg_u64 ||
+              instr->opcode == aco_opcode::s_cmp_eq_i32 ||
+              instr->opcode == aco_opcode::s_cmp_eq_u32 ||
+              instr->opcode == aco_opcode::s_cmp_eq_u64) {
+      combine_s_bitcmp(ctx, instr);
    } else if (instr->opcode == aco_opcode::v_and_b32) {
       combine_and_subbrev(ctx, instr);
    } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
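
Not part of the patch, but for reference: a minimal standalone sketch checking the
scalar identities the combine relies on. s_bitcmp1 sets SCC iff bit b of the first
operand is set, s_bitcmp0 iff it is clear, so both the compare-with-0 forms and the
compare-with-mask forms (the test1 = !lg branch, where the cmp's other operand is
the s_lshl result itself) reduce to a plain bit test. The sample values below are
arbitrary.

#include <cassert>
#include <cstdint>

int main()
{
   const uint64_t samples[] = {0, 1, 0x80000000u, 0xdeadbeefu,
                               0x8000000000000000ull, ~0ull};
   for (uint64_t a : samples) {
      for (unsigned b = 0; b < 64; b++) {
         uint64_t mask = uint64_t(1) << b;    /* s_lshl(1, b) */
         bool bit = (a >> b) & 1;             /* what s_bitcmp1(a, b) tests */
         /* s_cmp_lg(s_and(a, mask), 0) -> s_bitcmp1(a, b) */
         assert(((a & mask) != 0) == bit);
         /* s_cmp_eq(s_and(a, mask), 0) -> s_bitcmp0(a, b) */
         assert(((a & mask) == 0) == !bit);
         /* s_cmp_eq(s_and(a, mask), mask) -> s_bitcmp1(a, b) */
         assert(((a & mask) == mask) == bit);
         /* s_cmp_lg(s_and(a, mask), mask) -> s_bitcmp0(a, b) */
         assert(((a & mask) != mask) == !bit);
      }
   }
   return 0;
}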