diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index aa8f1192933..7c26c711f0d 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -311,6 +311,7 @@ struct opt_ctx {
    std::vector<ssa_info> info;
    std::vector<aco_ptr<Instruction>> pre_combine_instrs;
    std::vector<uint16_t> uses;
+   std::unordered_map<Instruction*, aco_ptr<Instruction>> replacement_instr;
 };
 
 aco_type
@@ -2975,35 +2976,6 @@ original_temp_id(opt_ctx& ctx, Temp tmp)
    return tmp.id();
 }
 
-Instruction*
-follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
-{
-   if (!op.isTemp())
-      return nullptr;
-   if (!ignore_uses && ctx.uses[op.tempId()] > 1)
-      return nullptr;
-
-   Instruction* instr = ctx.info[op.tempId()].parent_instr;
-
-   if (instr->definitions[0].getTemp() != op.getTemp())
-      return nullptr;
-
-   if (instr->definitions.size() == 2) {
-      unsigned idx =
-         instr->definitions[1].isTemp() && instr->definitions[1].tempId() == op.tempId();
-      assert(instr->definitions[idx].isTemp() && instr->definitions[idx].tempId() == op.tempId());
-      if (instr->definitions[!idx].isTemp() && ctx.uses[instr->definitions[!idx].tempId()])
-         return nullptr;
-   }
-
-   for (Operand& operand : instr->operands) {
-      if (fixed_to_exec(operand))
-         return nullptr;
-   }
-
-   return instr;
-}
-
 bool
 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
 {
@@ -3372,53 +3344,17 @@ match_and_apply_patterns(opt_ctx& ctx, alu_opt_info& info,
    return false;
 }
 
-/* s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
-bool
-combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
-{
-   if (ctx.uses[instr->definitions[1].tempId()])
-      return false;
-   if (!instr->operands[0].isTemp() || ctx.uses[instr->operands[0].tempId()] != 1)
-      return false;
-
-   Instruction* cmp = follow_operand(ctx, instr->operands[0]);
-   if (!cmp)
-      return false;
-
-   aco_opcode new_opcode = get_vcmp_inverse(cmp->opcode);
-   if (new_opcode == aco_opcode::num_opcodes)
-      return false;
-
-   /* Invert compare instruction and assign this instruction's definition */
-   cmp->opcode = new_opcode;
-   ctx.info[instr->definitions[0].tempId()] = ctx.info[cmp->definitions[0].tempId()];
-   std::swap(instr->definitions[0], cmp->definitions[0]);
-   ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
-   ctx.info[cmp->definitions[0].tempId()].parent_instr = cmp;
-
-   ctx.uses[instr->operands[0].tempId()]--;
-   return true;
-}
-
 /* v_not(v_xor(a, b)) -> v_xnor(a, b) */
-bool
-combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+Instruction*
+apply_v_not(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* op_instr)
 {
-   if (instr->usesModifiers())
-      return false;
+   if (ctx.program->gfx_level < GFX10 || instr->usesModifiers() ||
+       op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
+      return nullptr;
 
-   Instruction* op_instr = follow_operand(ctx, instr->operands[0]);
-   if (!op_instr || op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
-      return false;
-
-   ctx.uses[instr->operands[0].tempId()]--;
-   std::swap(instr->definitions[0], op_instr->definitions[0]);
+   op_instr->definitions[0] = instr->definitions[0];
    op_instr->opcode = aco_opcode::v_xnor_b32;
-   ctx.info[op_instr->definitions[0].tempId()].label = 0;
-   ctx.info[op_instr->definitions[0].tempId()].parent_instr = op_instr;
-   ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
-
-   return true;
+   return op_instr;
 }
 
 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
@@ -3426,61 +3362,47 @@ combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
- * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
-bool
-combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+ * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b)
+ * s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
+Instruction*
+apply_s_not(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* op_instr)
 {
-   /* checks */
-   if (!instr->operands[0].isTemp())
-      return false;
-   if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
-      return false;
+   if (op_instr->definitions.size() == 1 && ctx.uses[instr->definitions[1].tempId()])
+      return nullptr;
+   else if (op_instr->definitions.size() == 2 && ctx.uses[op_instr->definitions[1].tempId()])
+      return nullptr;
 
-   Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
-   if (!op2_instr)
-      return false;
-   switch (op2_instr->opcode) {
-   case aco_opcode::s_and_b32:
-   case aco_opcode::s_or_b32:
-   case aco_opcode::s_xor_b32:
-   case aco_opcode::s_and_b64:
-   case aco_opcode::s_or_b64:
-   case aco_opcode::s_xor_b64: break;
-   default: return false;
+   switch (op_instr->opcode) {
+   case aco_opcode::s_and_b32: op_instr->opcode = aco_opcode::s_nand_b32; break;
+   case aco_opcode::s_or_b32: op_instr->opcode = aco_opcode::s_nor_b32; break;
+   case aco_opcode::s_xor_b32: op_instr->opcode = aco_opcode::s_xnor_b32; break;
+   case aco_opcode::s_and_b64: op_instr->opcode = aco_opcode::s_nand_b64; break;
+   case aco_opcode::s_or_b64: op_instr->opcode = aco_opcode::s_nor_b64; break;
+   case aco_opcode::s_xor_b64: op_instr->opcode = aco_opcode::s_xnor_b64; break;
+   default: {
+      if (!op_instr->isVOPC())
+         return nullptr;
+      aco_opcode new_opcode = get_vcmp_inverse(op_instr->opcode);
+      if (new_opcode == aco_opcode::num_opcodes)
+         return nullptr;
+      op_instr->opcode = new_opcode;
+   }
    }
 
-   /* create instruction */
-   std::swap(instr->definitions[0], op2_instr->definitions[0]);
-   std::swap(instr->definitions[1], op2_instr->definitions[1]);
-   ctx.uses[instr->operands[0].tempId()]--;
-   ctx.info[op2_instr->definitions[0].tempId()].label = 0;
-   ctx.info[op2_instr->definitions[0].tempId()].parent_instr = op2_instr;
-   ctx.info[op2_instr->definitions[1].tempId()].parent_instr = op2_instr;
-   ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
-   ctx.info[instr->definitions[1].tempId()].parent_instr = instr.get();
+   for (unsigned i = 0; i < op_instr->definitions.size(); i++)
+      op_instr->definitions[i] = instr->definitions[i];
 
-   switch (op2_instr->opcode) {
-   case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
-   case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
-   case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
-   case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
-   case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
-   case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
-   default: break;
-   }
-
-   return true;
+   return op_instr;
 }
 
 /* s_abs_i32(s_sub_[iu]32(a, b)) -> s_absdiff_i32(a, b)
  * s_abs_i32(s_add_[iu]32(a, #b)) -> s_absdiff_i32(a, -b) */
-bool
-combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+Instruction*
+apply_s_abs(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* op_instr)
 {
-   Instruction* op_instr = follow_operand(ctx, instr->operands[0], false);
-   if (!op_instr)
-      return false;
+   if (op_instr->definitions.size() != 2 || ctx.uses[op_instr->definitions[1].tempId()])
+      return nullptr;
 
    if (op_instr->opcode == aco_opcode::s_add_i32 || op_instr->opcode == aco_opcode::s_add_u32) {
       for (unsigned i = 0; i < 2; i++) {
@@ -3489,30 +3411,21 @@ combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
              !is_operand_constant(ctx, op_instr->operands[i], 32, &constant))
             continue;
 
-         if (op_instr->operands[i].isTemp())
-            ctx.uses[op_instr->operands[i].tempId()]--;
          op_instr->operands[0] = op_instr->operands[!i];
          op_instr->operands[1] = Operand::c32(-int32_t(constant));
          goto use_absdiff;
       }
-      return false;
+      return nullptr;
    } else if (op_instr->opcode != aco_opcode::s_sub_i32 &&
               op_instr->opcode != aco_opcode::s_sub_u32) {
-      return false;
+      return nullptr;
    }
 
 use_absdiff:
    op_instr->opcode = aco_opcode::s_absdiff_i32;
-   std::swap(instr->definitions[0], op_instr->definitions[0]);
-   std::swap(instr->definitions[1], op_instr->definitions[1]);
-   ctx.uses[instr->operands[0].tempId()]--;
-   ctx.info[op_instr->definitions[0].tempId()].label = 0;
-   ctx.info[op_instr->definitions[0].tempId()].parent_instr = op_instr;
-   ctx.info[op_instr->definitions[1].tempId()].parent_instr = op_instr;
-   ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
-   ctx.info[instr->definitions[1].tempId()].parent_instr = instr.get();
-
-   return true;
+   op_instr->definitions[0] = instr->definitions[0];
+   op_instr->definitions[1] = instr->definitions[1];
+   return op_instr;
 }
 
 bool
@@ -3654,18 +3567,9 @@ apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 
 /* Remove superfluous extract after ds_read like so:
  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN() */
-bool
-apply_load_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
+Instruction*
+apply_load_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract, Instruction* load)
 {
-   /* Check if p_extract has a usedef operand and is the only user. */
-   if (ctx.uses[extract->operands[0].tempId()] > 1)
-      return false;
-
-   /* Check if the usedef is the right format. */
-   Instruction* load = ctx.info[extract->operands[0].tempId()].parent_instr;
-   if (!load->isDS() && !load->isSMEM() && !load->isMUBUF() && !load->isFlatLike())
-      return false;
-
    unsigned extract_idx = extract->operands[1].constantValue();
    unsigned bits_extracted = extract->operands[2].constantValue();
    bool sign_ext = extract->operands[3].constantValue();
@@ -3698,17 +3602,17 @@ apply_load_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
    case aco_opcode::s_buffer_load_ushort:
    case aco_opcode::buffer_load_ushort:
    case aco_opcode::buffer_load_short_d16: bits_loaded = 16; break;
-   default: return false;
+   default: return nullptr;
    }
 
    /* TODO: These are doable, but probably don't occur too often. */
    if (extract_idx || bits_extracted > bits_loaded || dst_bitsize > 32 ||
       (load->definitions[0].regClass().type() != extract->definitions[0].regClass().type()))
-      return false;
+      return nullptr;
 
    /* We can't shrink some loads because that would remove zeroing of the offset/address LSBs. */
    if (!can_shrink && bits_extracted < bits_loaded)
-      return false;
+      return nullptr;
 
    /* Shrink the load if the extracted bit size is smaller. */
    bits_loaded = MIN2(bits_loaded, bits_extracted);
@@ -3774,12 +3678,8 @@ apply_load_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
    }
 
    /* The load now produces the exact same thing as the extract, remove the extract. */
-   std::swap(load->definitions[0], extract->definitions[0]);
-   ctx.uses[extract->definitions[0].tempId()] = 0;
-   ctx.info[load->definitions[0].tempId()].label = 0;
-   ctx.info[extract->definitions[0].tempId()].parent_instr = extract.get();
-   ctx.info[load->definitions[0].tempId()].parent_instr = load;
-   return true;
+   load->definitions[0] = extract->definitions[0];
+   return load;
 }
 
 void
@@ -3940,6 +3840,102 @@ op_info_get_constant(opt_ctx& ctx, alu_opt_op op_info, aco_type type, uint64_t*
    return true;
 }
 
+Instruction*
+apply_output_impl(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
+{
+   if (instr->opcode == aco_opcode::p_extract &&
+       (parent->isDS() || parent->isSMEM() || parent->isMUBUF() || parent->isFlatLike()))
+      return apply_load_extract(ctx, instr, parent);
+   else if (instr->opcode == aco_opcode::p_extract)
+      return nullptr;
+   else if (instr->opcode == aco_opcode::v_not_b32)
+      return apply_v_not(ctx, instr, parent);
+   else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64)
+      return apply_s_not(ctx, instr, parent);
+   else if (instr->opcode == aco_opcode::s_abs_i32)
+      return apply_s_abs(ctx, instr, parent);
+   else
+      UNREACHABLE("unhandled opcode");
+
+   return nullptr;
+}
+
+bool
+apply_output(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   switch (instr->opcode) {
+   case aco_opcode::p_extract:
+   case aco_opcode::v_not_b32:
+   case aco_opcode::s_not_b32:
+   case aco_opcode::s_not_b64:
+   case aco_opcode::s_abs_i32: break;
+   default: return false;
+   }
+
+   int temp_idx = -1;
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
+      if (temp_idx < 0 && instr->operands[i].isTemp())
+         temp_idx = i;
+      else if (instr->operands[i].isConstant())
+         continue;
+      else
+         return false;
+   }
+
+   if (temp_idx < 0)
+      return false;
+
+   unsigned tmpid = instr->operands[temp_idx].tempId();
+   Instruction* parent = ctx.info[tmpid].parent_instr;
+   if (ctx.uses[tmpid] != 1 || parent->definitions[0].tempId() != tmpid)
+      return false;
+
+   int64_t alt_idx = ctx.info[tmpid].is_combined() ? ctx.info[tmpid].val : -1;
+   aco::small_vec<Operand, 4> pre_opt_ops;
+   for (const Operand& op : parent->operands)
+      pre_opt_ops.push_back(op);
+
+   Instruction* new_instr = apply_output_impl(ctx, instr, parent);
+
+   if (new_instr == nullptr)
+      return false;
+
+   for (const Operand& op : parent->operands) {
+      if (op.isTemp())
+         ctx.uses[op.tempId()]++;
+   }
+   for (const Operand& op : pre_opt_ops) {
+      if (op.isTemp())
+         decrease_and_dce(ctx, op.getTemp());
+   }
+
+   ctx.uses[tmpid] = 0;
+   ctx.info[tmpid].parent_instr = nullptr;
+
+   if (new_instr != parent)
+      ctx.replacement_instr.emplace(parent, new_instr);
+
+   if (alt_idx >= 0) {
+      Instruction* new_pre_combine =
+         apply_output_impl(ctx, instr, ctx.pre_combine_instrs[alt_idx].get());
+
+      if (new_pre_combine != ctx.pre_combine_instrs[alt_idx].get())
+         ctx.pre_combine_instrs[alt_idx].reset(new_pre_combine);
+
+      if (new_pre_combine)
+         ctx.info[new_instr->definitions[0].tempId()].set_combined(alt_idx);
+   }
+
+   for (Definition& def : new_instr->definitions) {
+      ctx.info[def.tempId()].parent_instr = new_instr;
+      ctx.info[def.tempId()].label &=
+         instr_mod_labels | canonicalized_labels | label_combined_instr;
+   }
+
+   instr.reset();
+   return true;
+}
+
 bool
 create_fma_cb(opt_ctx& ctx, alu_opt_info& info)
 {
@@ -4165,9 +4161,11 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    if (instr->isDPP())
       return;
 
-   if (instr->opcode == aco_opcode::p_extract) {
-      apply_load_extract(ctx, instr);
-   }
+   if (!instr->isVALU() && !instr->isSALU() && !instr->isPseudo())
+      return;
+
+   if (apply_output(ctx, instr))
+      return;
 
    /* TODO: There are still some peephole optimizations that could be done:
     * - abs(a - b) -> s_absdiff_i32
@@ -4230,15 +4228,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       return;
    }
 
-   if (instr->opcode == aco_opcode::v_not_b32 && ctx.program->gfx_level >= GFX10) {
-      combine_not_xor(ctx, instr);
-   } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
-      if (!combine_salu_not_bitwise(ctx, instr))
-         combine_inverse_comparison(ctx, instr);
-   } else if (instr->opcode == aco_opcode::s_abs_i32) {
-      combine_sabsdiff(ctx, instr);
-   }
-
    alu_opt_info info;
   if (!alu_opt_gather_info(ctx, instr.get(), info))
      return;
@@ -4744,12 +4733,29 @@ to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    return true;
 }
 
+void
+insert_replacement_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+   if (!instr.get() || instr->definitions.empty() ||
+       ctx.info[instr->definitions[0].tempId()].parent_instr == instr.get())
+      return;
+
+   while (true) {
+      auto it = ctx.replacement_instr.find(instr.get());
+      if (it == ctx.replacement_instr.end())
+         return;
+
+      instr = std::move(it->second);
+      ctx.replacement_instr.erase(it);
+   }
+}
+
 void
 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 {
    const uint32_t threshold = 4;
 
-   if (is_dead(ctx.uses, instr.get())) {
+   if (!instr.get() || is_dead(ctx.uses, instr.get())) {
       instr.reset();
       return;
    }
@@ -4840,8 +4846,12 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
        * no operand instruction was eliminated.
       */
      bool use_prev = std::all_of(
-         prev_instr->operands.begin(), prev_instr->operands.end(), [&](Operand op)
-         { return !op.isTemp() || !is_dead(ctx.uses, ctx.info[op.tempId()].parent_instr); });
+         prev_instr->operands.begin(), prev_instr->operands.end(),
+         [&](Operand op)
+         {
+            return !op.isTemp() || (ctx.info[op.tempId()].parent_instr &&
+                                    !is_dead(ctx.uses, ctx.info[op.tempId()].parent_instr));
+         });
 
      if (use_prev) {
         for (const Operand& op : prev_instr->operands) {
@@ -5456,6 +5466,14 @@ optimize(Program* program)
          combine_instruction(ctx, instr);
    }
 
+   if (!ctx.replacement_instr.empty()) {
+      for (Block& block : program->blocks) {
+         ctx.fp_mode = block.fp_mode;
+         for (aco_ptr<Instruction>& instr : block.instructions)
+            insert_replacement_instr(ctx, instr);
+      }
+   }
+
    validate_opt_ctx(ctx);
 
    /* 4. Top-Down DAG pass (backward) to select instructions (includes DCE) */