aco/optimizer: use new helpers to create v_fma_mixlo_f16

Foz-DB Navi21:
Totals from 69 (0.07% of 97591) affected shaders:
Instrs: 45091 -> 45057 (-0.08%)
CodeSize: 244016 -> 243932 (-0.03%); split: -0.12%, +0.09%
VGPRs: 1792 -> 1680 (-6.25%)
Latency: 133496 -> 133572 (+0.06%); split: -0.03%, +0.09%
InvThroughput: 35383 -> 35338 (-0.13%)
Copies: 4050 -> 4048 (-0.05%)
VALU: 30172 -> 30138 (-0.11%)

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38658>
Authored by Georg Lehmann on 2025-01-09 22:38:41 +01:00; committed by Marge Bot.
parent ee28801eae
commit 69b5767eee
2 changed files with 28 additions and 111 deletions

View file

@ -70,11 +70,9 @@ enum Label {
label_omod4 = 1ull << 34, label_omod4 = 1ull << 34,
label_omod5 = 1ull << 35, label_omod5 = 1ull << 35,
label_clamp = 1ull << 36, label_clamp = 1ull << 36,
label_f2f16 = 1ull << 39,
}; };
static constexpr uint64_t instr_mod_labels = static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp;
label_omod2 | label_omod4 | label_omod5 | label_clamp | label_f2f16;
static constexpr uint64_t input_mod_labels = static constexpr uint64_t input_mod_labels =
label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64; label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64;
@ -236,16 +234,6 @@ struct ssa_info {
bool is_clamp() { return label & label_clamp; } bool is_clamp() { return label & label_clamp; }
void set_f2f16(Instruction* conv)
{
if (label & temp_labels)
return;
add_label(label_f2f16);
mod_instr = conv;
}
bool is_f2f16() { return label & label_f2f16; }
void set_uniform_bitwise() { add_label(label_uniform_bitwise); } void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
bool is_uniform_bitwise() { return label & label_uniform_bitwise; } bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
@ -3016,11 +3004,6 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
ctx.info[instr->definitions[0].tempId()].set_extract(); ctx.info[instr->definitions[0].tempId()].set_extract();
break; break;
} }
case aco_opcode::v_cvt_f16_f32: {
if (instr->operands[0].isTemp())
ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
break;
}
default: break; default: break;
} }
@ -3583,7 +3566,7 @@ apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->valu().clamp = true; instr->valu().clamp = true;
instr->definitions[0].swapTemp(def_info.mod_instr->definitions[0]); instr->definitions[0].swapTemp(def_info.mod_instr->definitions[0]);
ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_f2f16; ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
ctx.uses[def_info.mod_instr->definitions[0].tempId()]--; ctx.uses[def_info.mod_instr->definitions[0].tempId()]--;
ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
ctx.info[def_info.mod_instr->definitions[0].tempId()].parent_instr = def_info.mod_instr; ctx.info[def_info.mod_instr->definitions[0].tempId()].parent_instr = def_info.mod_instr;
@ -3740,102 +3723,34 @@ apply_load_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract, Instruction* loa
return load; return load;
} }
bool Instruction*
can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr) apply_f2f16(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
{ {
if (ctx.program->gfx_level < GFX9)
return false;
/* unfused v_mad_mix* always flushes 16/32-bit denormal inputs/outputs */
if (!ctx.program->dev.fused_mad_mix && ctx.fp_mode.denorm)
return false;
if (instr->valu().omod) if (instr->valu().omod)
return false; return nullptr;
switch (instr->opcode) { alu_opt_info info;
case aco_opcode::v_add_f32: if (!alu_opt_gather_info(ctx, instr.get(), info))
case aco_opcode::v_sub_f32: return nullptr;
case aco_opcode::v_subrev_f32: aco_type type = {aco_base_type_float, 1, 32};
case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
case aco_opcode::v_fma_f32:
return ctx.program->dev.fused_mad_mix;
case aco_opcode::v_fma_mix_f32:
case aco_opcode::v_fma_mixlo_f16: return true;
default: return false;
}
}
void alu_opt_info parent_info;
to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr) if (!alu_opt_gather_info(ctx, parent, parent_info))
{ return nullptr;
ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp;
if (instr->opcode == aco_opcode::v_fma_f32) { if (parent_info.uses_insert() || parent_info.f32_to_f16)
instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P)); return nullptr;
instr->opcode = aco_opcode::v_fma_mix_f32;
return;
}
bool is_add = instr->opcode != aco_opcode::v_mul_f32; if (!backpropagate_input_modifiers(ctx, parent_info, info.operands[0], type))
return nullptr;
aco_ptr<Instruction> vop3p{create_instruction(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)}; parent_info.f32_to_f16 = true;
parent_info.clamp |= info.clamp;
for (unsigned i = 0; i < instr->operands.size(); i++) { parent_info.defs[0].setTemp(info.defs[0].getTemp());
vop3p->operands[is_add + i] = instr->operands[i]; if (!alu_opt_info_is_valid(ctx, parent_info))
vop3p->valu().neg_lo[is_add + i] = instr->valu().neg[i]; return nullptr;
vop3p->valu().neg_hi[is_add + i] = instr->valu().abs[i]; return alu_opt_info_to_instr(ctx, parent_info, parent);
}
if (instr->opcode == aco_opcode::v_mul_f32) {
vop3p->operands[2] = Operand::zero();
vop3p->valu().neg_lo[2] = true;
} else if (is_add) {
vop3p->operands[0] = Operand::c32(0x3f800000);
if (instr->opcode == aco_opcode::v_sub_f32)
vop3p->valu().neg_lo[2] ^= true;
else if (instr->opcode == aco_opcode::v_subrev_f32)
vop3p->valu().neg_lo[1] ^= true;
}
vop3p->definitions[0] = instr->definitions[0];
vop3p->valu().clamp = instr->valu().clamp;
vop3p->pass_flags = instr->pass_flags;
instr = std::move(vop3p);
ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
}
bool
combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
if (!def_info.is_f2f16())
return false;
Instruction* conv = def_info.mod_instr;
if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1)
return false;
if (conv->usesModifiers())
return false;
if (interp_can_become_fma(ctx, instr))
interp_p2_f32_inreg_to_fma_dpp(instr);
if (!can_use_mad_mix(ctx, instr))
return false;
if (!instr->isVOP3P())
to_mad_mix(ctx, instr);
instr->opcode = aco_opcode::v_fma_mixlo_f16;
instr->definitions[0].swapTemp(conv->definitions[0]);
if (conv->definitions[0].isPrecise())
instr->definitions[0].setPrecise(true);
ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
ctx.uses[conv->definitions[0].tempId()]--;
ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
ctx.info[conv->definitions[0].tempId()].parent_instr = conv;
return true;
} }
bool bool
@ -3930,6 +3845,8 @@ apply_output_impl(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent
instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16 || instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16 ||
instr->opcode == aco_opcode::v_pk_mul_f16) instr->opcode == aco_opcode::v_pk_mul_f16)
return apply_output_mul(ctx, instr, parent); return apply_output_mul(ctx, instr, parent);
else if (instr->opcode == aco_opcode::v_cvt_f16_f32)
return apply_f2f16(ctx, instr, parent);
else else
UNREACHABLE("unhandled opcode"); UNREACHABLE("unhandled opcode");
@ -3950,7 +3867,8 @@ apply_output(opt_ctx& ctx, aco_ptr<Instruction>& instr)
case aco_opcode::v_mul_f64_e64: case aco_opcode::v_mul_f64_e64:
case aco_opcode::v_mul_f32: case aco_opcode::v_mul_f32:
case aco_opcode::v_mul_f16: case aco_opcode::v_mul_f16:
case aco_opcode::v_pk_mul_f16: break; case aco_opcode::v_pk_mul_f16:
case aco_opcode::v_cvt_f16_f32: break;
default: return false; default: return false;
} }
@ -4215,7 +4133,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
} }
if (instr->isVALU()) { if (instr->isVALU()) {
while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr)) while (apply_omod_clamp(ctx, instr))
; ;
} }

View file

@ -1308,8 +1308,7 @@ BEGIN_TEST(optimize.mad_mix.output_conv.modifiers)
//! p_unit_test 0, %res0 //! p_unit_test 0, %res0
writeout(0, f2f16(fabs(fadd(a, b)))); writeout(0, f2f16(fabs(fadd(a, b))));
//! v1: %res1_add = v_add_f32 %1, %2 //! v2b: %res1 = v_fma_mixlo_f16 1.0, -%a, -%b
//! v2b: %res1 = v_cvt_f16_f32 -%res1_add
//! p_unit_test 1, %res1 //! p_unit_test 1, %res1
writeout(1, f2f16(fneg(fadd(a, b)))); writeout(1, f2f16(fneg(fadd(a, b))));