From 69b5767eee387b11cadbf238fcacb04eace32aaa Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Thu, 9 Jan 2025 22:38:41 +0100 Subject: [PATCH] aco/optimizer: use new helpers to create v_fma_mixlo_f16 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Foz-DB Navi21: Totals from 69 (0.07% of 97591) affected shaders: Instrs: 45091 -> 45057 (-0.08%) CodeSize: 244016 -> 243932 (-0.03%); split: -0.12%, +0.09% VGPRs: 1792 -> 1680 (-6.25%) Latency: 133496 -> 133572 (+0.06%); split: -0.03%, +0.09% InvThroughput: 35383 -> 35338 (-0.13%) Copies: 4050 -> 4048 (-0.05%) VALU: 30172 -> 30138 (-0.11%) Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_optimizer.cpp | 136 +++++----------------- src/amd/compiler/tests/test_optimizer.cpp | 3 +- 2 files changed, 28 insertions(+), 111 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index b1c9e071f01..7c56ca8f9be 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -70,11 +70,9 @@ enum Label { label_omod4 = 1ull << 34, label_omod5 = 1ull << 35, label_clamp = 1ull << 36, - label_f2f16 = 1ull << 39, }; -static constexpr uint64_t instr_mod_labels = - label_omod2 | label_omod4 | label_omod5 | label_clamp | label_f2f16; +static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp; static constexpr uint64_t input_mod_labels = label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64; @@ -236,16 +234,6 @@ struct ssa_info { bool is_clamp() { return label & label_clamp; } - void set_f2f16(Instruction* conv) - { - if (label & temp_labels) - return; - add_label(label_f2f16); - mod_instr = conv; - } - - bool is_f2f16() { return label & label_f2f16; } - void set_uniform_bitwise() { add_label(label_uniform_bitwise); } bool is_uniform_bitwise() { return label & label_uniform_bitwise; } @@ -3016,11 +3004,6 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) ctx.info[instr->definitions[0].tempId()].set_extract(); break; } - case aco_opcode::v_cvt_f16_f32: { - if (instr->operands[0].isTemp()) - ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get()); - break; - } default: break; } @@ -3583,7 +3566,7 @@ apply_omod_clamp(opt_ctx& ctx, aco_ptr& instr) instr->valu().clamp = true; instr->definitions[0].swapTemp(def_info.mod_instr->definitions[0]); - ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_f2f16; + ctx.info[instr->definitions[0].tempId()].label &= label_clamp; ctx.uses[def_info.mod_instr->definitions[0].tempId()]--; ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); ctx.info[def_info.mod_instr->definitions[0].tempId()].parent_instr = def_info.mod_instr; @@ -3740,102 +3723,34 @@ apply_load_extract(opt_ctx& ctx, aco_ptr& extract, Instruction* loa return load; } -bool -can_use_mad_mix(opt_ctx& ctx, aco_ptr& instr) +Instruction* +apply_f2f16(opt_ctx& ctx, aco_ptr& instr, Instruction* parent) { - if (ctx.program->gfx_level < GFX9) - return false; - - /* unfused v_mad_mix* always flushes 16/32-bit denormal inputs/outputs */ - if (!ctx.program->dev.fused_mad_mix && ctx.fp_mode.denorm) - return false; - if (instr->valu().omod) - return false; + return nullptr; - switch (instr->opcode) { - case aco_opcode::v_add_f32: - case aco_opcode::v_sub_f32: - case aco_opcode::v_subrev_f32: - case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP(); - case aco_opcode::v_fma_f32: - return ctx.program->dev.fused_mad_mix; - case aco_opcode::v_fma_mix_f32: - case aco_opcode::v_fma_mixlo_f16: return true; - default: return false; - } -} + alu_opt_info info; + if (!alu_opt_gather_info(ctx, instr.get(), info)) + return nullptr; + aco_type type = {aco_base_type_float, 1, 32}; -void -to_mad_mix(opt_ctx& ctx, aco_ptr& instr) -{ - ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp; + alu_opt_info parent_info; + if (!alu_opt_gather_info(ctx, parent, parent_info)) + return nullptr; - if (instr->opcode == aco_opcode::v_fma_f32) { - instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P)); - instr->opcode = aco_opcode::v_fma_mix_f32; - return; - } + if (parent_info.uses_insert() || parent_info.f32_to_f16) + return nullptr; - bool is_add = instr->opcode != aco_opcode::v_mul_f32; + if (!backpropagate_input_modifiers(ctx, parent_info, info.operands[0], type)) + return nullptr; - aco_ptr vop3p{create_instruction(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)}; + parent_info.f32_to_f16 = true; + parent_info.clamp |= info.clamp; - for (unsigned i = 0; i < instr->operands.size(); i++) { - vop3p->operands[is_add + i] = instr->operands[i]; - vop3p->valu().neg_lo[is_add + i] = instr->valu().neg[i]; - vop3p->valu().neg_hi[is_add + i] = instr->valu().abs[i]; - } - if (instr->opcode == aco_opcode::v_mul_f32) { - vop3p->operands[2] = Operand::zero(); - vop3p->valu().neg_lo[2] = true; - } else if (is_add) { - vop3p->operands[0] = Operand::c32(0x3f800000); - if (instr->opcode == aco_opcode::v_sub_f32) - vop3p->valu().neg_lo[2] ^= true; - else if (instr->opcode == aco_opcode::v_subrev_f32) - vop3p->valu().neg_lo[1] ^= true; - } - vop3p->definitions[0] = instr->definitions[0]; - vop3p->valu().clamp = instr->valu().clamp; - vop3p->pass_flags = instr->pass_flags; - instr = std::move(vop3p); - ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); -} - -bool -combine_output_conversion(opt_ctx& ctx, aco_ptr& instr) -{ - ssa_info& def_info = ctx.info[instr->definitions[0].tempId()]; - if (!def_info.is_f2f16()) - return false; - Instruction* conv = def_info.mod_instr; - - if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1) - return false; - - if (conv->usesModifiers()) - return false; - - if (interp_can_become_fma(ctx, instr)) - interp_p2_f32_inreg_to_fma_dpp(instr); - - if (!can_use_mad_mix(ctx, instr)) - return false; - - if (!instr->isVOP3P()) - to_mad_mix(ctx, instr); - - instr->opcode = aco_opcode::v_fma_mixlo_f16; - instr->definitions[0].swapTemp(conv->definitions[0]); - if (conv->definitions[0].isPrecise()) - instr->definitions[0].setPrecise(true); - ctx.info[instr->definitions[0].tempId()].label &= label_clamp; - ctx.uses[conv->definitions[0].tempId()]--; - ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); - ctx.info[conv->definitions[0].tempId()].parent_instr = conv; - - return true; + parent_info.defs[0].setTemp(info.defs[0].getTemp()); + if (!alu_opt_info_is_valid(ctx, parent_info)) + return nullptr; + return alu_opt_info_to_instr(ctx, parent_info, parent); } bool @@ -3930,6 +3845,8 @@ apply_output_impl(opt_ctx& ctx, aco_ptr& instr, Instruction* parent instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16 || instr->opcode == aco_opcode::v_pk_mul_f16) return apply_output_mul(ctx, instr, parent); + else if (instr->opcode == aco_opcode::v_cvt_f16_f32) + return apply_f2f16(ctx, instr, parent); else UNREACHABLE("unhandled opcode"); @@ -3950,7 +3867,8 @@ apply_output(opt_ctx& ctx, aco_ptr& instr) case aco_opcode::v_mul_f64_e64: case aco_opcode::v_mul_f32: case aco_opcode::v_mul_f16: - case aco_opcode::v_pk_mul_f16: break; + case aco_opcode::v_pk_mul_f16: + case aco_opcode::v_cvt_f16_f32: break; default: return false; } @@ -4215,7 +4133,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) } if (instr->isVALU()) { - while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr)) + while (apply_omod_clamp(ctx, instr)) ; } diff --git a/src/amd/compiler/tests/test_optimizer.cpp b/src/amd/compiler/tests/test_optimizer.cpp index 67a7e48dd00..579da333361 100644 --- a/src/amd/compiler/tests/test_optimizer.cpp +++ b/src/amd/compiler/tests/test_optimizer.cpp @@ -1308,8 +1308,7 @@ BEGIN_TEST(optimize.mad_mix.output_conv.modifiers) //! p_unit_test 0, %res0 writeout(0, f2f16(fabs(fadd(a, b)))); - //! v1: %res1_add = v_add_f32 %1, %2 - //! v2b: %res1 = v_cvt_f16_f32 -%res1_add + //! v2b: %res1 = v_fma_mixlo_f16 1.0, -%a, -%b //! p_unit_test 1, %res1 writeout(1, f2f16(fneg(fadd(a, b))));