diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 327612b3f6a..211627d0854 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -2072,9 +2072,9 @@ original_temp_id(opt_ctx& ctx, Temp tmp) } void -decrease_uses(opt_ctx& ctx, Instruction* instr) +decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr) { - if (!--ctx.uses[instr->definitions[0].tempId()]) { + if (is_dead(ctx.uses, instr)) { for (const Operand& op : instr->operands) { if (op.isTemp()) ctx.uses[op.tempId()]--; @@ -2082,6 +2082,21 @@ decrease_uses(opt_ctx& ctx, Instruction* instr) } } +void +decrease_uses(opt_ctx& ctx, Instruction* instr) +{ + ctx.uses[instr->definitions[0].tempId()]--; + decrease_op_uses_if_dead(ctx, instr); +} + +Operand +copy_operand(opt_ctx& ctx, Operand op) +{ + if (op.isTemp()) + ctx.uses[op.tempId()]++; + return op; +} + Instruction* follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false) { @@ -2162,11 +2177,6 @@ combine_ordering_test(opt_ctx& ctx, aco_ptr& instr) if (num_sgprs > (ctx.program->gfx_level >= GFX10 ? 2 : 1)) return false; - ctx.uses[op[0].id()]++; - ctx.uses[op[1].id()]++; - decrease_uses(ctx, op_instr[0]); - decrease_uses(ctx, op_instr[1]); - aco_opcode new_op = aco_opcode::num_opcodes; switch (bitsize) { case 16: new_op = is_or ? aco_opcode::v_cmp_u_f16 : aco_opcode::v_cmp_o_f16; break; @@ -2186,10 +2196,13 @@ combine_ordering_test(opt_ctx& ctx, aco_ptr& instr) } else { new_instr = create_instruction(new_op, Format::VOPC, 2, 1); } - new_instr->operands[0] = Operand(op[0]); - new_instr->operands[1] = Operand(op[1]); + new_instr->operands[0] = copy_operand(ctx, Operand(op[0])); + new_instr->operands[1] = copy_operand(ctx, Operand(op[1])); new_instr->definitions[0] = instr->definitions[0]; + decrease_uses(ctx, op_instr[0]); + decrease_uses(ctx, op_instr[1]); + ctx.info[instr->definitions[0].tempId()].label = 0; ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr); @@ -2240,11 +2253,6 @@ combine_comparison_ordering(opt_ctx& ctx, aco_ptr& instr) if (prop_cmp1 != prop_nan0 && prop_cmp1 != prop_nan1) return false; - ctx.uses[cmp->operands[0].tempId()]++; - ctx.uses[cmp->operands[1].tempId()]++; - decrease_uses(ctx, nan_test); - decrease_uses(ctx, cmp); - aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode); Instruction* new_instr; if (cmp->isVOP3()) { @@ -2260,10 +2268,13 @@ combine_comparison_ordering(opt_ctx& ctx, aco_ptr& instr) } else { new_instr = create_instruction(new_op, Format::VOPC, 2, 1); } - new_instr->operands[0] = cmp->operands[0]; - new_instr->operands[1] = cmp->operands[1]; + new_instr->operands[0] = copy_operand(ctx, cmp->operands[0]); + new_instr->operands[1] = copy_operand(ctx, cmp->operands[1]); new_instr->definitions[0] = instr->definitions[0]; + decrease_uses(ctx, nan_test); + decrease_uses(ctx, cmp); + ctx.info[instr->definitions[0].tempId()].label = 0; ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr); @@ -2363,13 +2374,6 @@ combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr& instr) if (is_constant_nan(constant_value, bit_size)) return false; - if (cmp->operands[0].isTemp()) - ctx.uses[cmp->operands[0].tempId()]++; - if (cmp->operands[1].isTemp()) - ctx.uses[cmp->operands[1].tempId()]++; - decrease_uses(ctx, nan_test); - decrease_uses(ctx, cmp); - aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode); Instruction* new_instr; if (cmp->isVOP3()) { @@ -2385,10 +2389,13 @@ combine_constant_comparison_ordering(opt_ctx& ctx, aco_ptr& instr) } else { new_instr = create_instruction(new_op, Format::VOPC, 2, 1); } - new_instr->operands[0] = cmp->operands[0]; - new_instr->operands[1] = cmp->operands[1]; + new_instr->operands[0] = copy_operand(ctx, cmp->operands[0]); + new_instr->operands[1] = copy_operand(ctx, cmp->operands[1]); new_instr->definitions[0] = instr->definitions[0]; + decrease_uses(ctx, nan_test); + decrease_uses(ctx, cmp); + ctx.info[instr->definitions[0].tempId()].label = 0; ctx.info[instr->definitions[0].tempId()].set_vopc(new_instr); @@ -2748,9 +2755,9 @@ combine_salu_lshl_add(opt_ctx& ctx, aco_ptr& instr) instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue()) continue; - ctx.uses[instr->operands[i].tempId()]--; instr->operands[1] = instr->operands[!i]; - instr->operands[0] = op2_instr->operands[0]; + instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]); + decrease_uses(ctx, op2_instr); ctx.info[instr->definitions[0].tempId()].label = 0; instr->opcode = std::array{ @@ -3302,15 +3309,12 @@ combine_and_subbrev(opt_ctx& ctx, aco_ptr& instr) return false; } - ctx.uses[instr->operands[i].tempId()]--; - if (ctx.uses[instr->operands[i].tempId()]) - ctx.uses[op_instr->operands[2].tempId()]++; - new_instr->operands[0] = Operand::zero(); new_instr->operands[1] = instr->operands[!i]; - new_instr->operands[2] = Operand(op_instr->operands[2]); + new_instr->operands[2] = copy_operand(ctx, op_instr->operands[2]); new_instr->definitions[0] = instr->definitions[0]; instr = std::move(new_instr); + decrease_uses(ctx, op_instr); ctx.info[instr->definitions[0].tempId()].label = 0; return true; }