aco: simplify multiply-add combining

When both operands of a v_sub (the same applies to v_add) are
multiplications and one of them already uses clamp/omod, pick the other
operand so that there is still a chance to combine them into a MAD.

No fossils-db changes.

Co-authored-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>
This commit is contained in:
Daniel Schürmann 2020-09-02 15:19:21 +01:00 committed by Marge Bot
parent fcd2ef23e5
commit 01134b0bfe

View file

@ -2797,49 +2797,50 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
if (need_fma && mad32 && !ctx.program->has_fast_fma32)
return;
uint32_t uses_src0 = UINT32_MAX;
uint32_t uses_src1 = UINT32_MAX;
Instruction* mul_instr = nullptr;
unsigned add_op_idx;
/* check if any of the operands is a multiplication */
ssa_info *op0_info = instr->operands[0].isTemp() ? &ctx.info[instr->operands[0].tempId()] : NULL;
ssa_info *op1_info = instr->operands[1].isTemp() ? &ctx.info[instr->operands[1].tempId()] : NULL;
if (op0_info && op0_info->is_mul() && (!need_fma || !op0_info->instr->definitions[0].isPrecise()))
uses_src0 = ctx.uses[instr->operands[0].tempId()];
if (op1_info && op1_info->is_mul() && (!need_fma || !op1_info->instr->definitions[0].isPrecise()))
uses_src1 = ctx.uses[instr->operands[1].tempId()];
uint32_t uses = UINT32_MAX;
/* find the 'best' mul instruction to combine with the add */
if (uses_src0 < uses_src1) {
mul_instr = op0_info->instr;
add_op_idx = 1;
} else if (uses_src1 < uses_src0) {
mul_instr = op1_info->instr;
add_op_idx = 0;
} else if (uses_src0 != UINT32_MAX) {
/* tiebreaker: quite random what to pick */
if (op0_info->instr->operands[0].isLiteral()) {
mul_instr = op1_info->instr;
add_op_idx = 0;
} else {
mul_instr = op0_info->instr;
add_op_idx = 1;
}
for (unsigned i = 0; i < 2; i++) {
if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
continue;
/* check precision requirements */
ssa_info& info = ctx.info[instr->operands[i].tempId()];
if (need_fma && info.instr->definitions[0].isPrecise())
continue;
/* no clamp/omod allowed between mul and add */
if (info.instr->isVOP3() &&
(static_cast<VOP3A_instruction*>(info.instr)->clamp ||
static_cast<VOP3A_instruction*>(info.instr)->omod))
continue;
Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
if (info.instr->isSDWA() ||
!check_vop3_operands(ctx, 3, op) ||
ctx.uses[instr->operands[i].tempId()] >= uses)
continue;
mul_instr = info.instr;
add_op_idx = 1 - i;
uses = ctx.uses[instr->operands[i].tempId()];
}
if (mul_instr) {
Operand op[3] = {Operand(v1), Operand(v1), Operand(v1)};
/* turn mul+add into v_mad/v_fma */
Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]};
ctx.uses[mul_instr->definitions[0].tempId()]--;
if (ctx.uses[mul_instr->definitions[0].tempId()]) {
if (op[0].isTemp())
ctx.uses[op[0].tempId()]++;
if (op[1].isTemp())
ctx.uses[op[1].tempId()]++;
}
bool neg[3] = {false, false, false};
bool abs[3] = {false, false, false};
unsigned omod = 0;
bool clamp = false;
op[0] = mul_instr->operands[0];
op[1] = mul_instr->operands[1];
op[2] = instr->operands[add_op_idx];
// TODO: would be better to check this before selecting a mul instr?
if (!check_vop3_operands(ctx, 3, op))
return;
if (mul_instr->isSDWA())
return;
if (mul_instr->isVOP3()) {
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
@ -2847,18 +2848,6 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
neg[1] = vop3->neg[1];
abs[0] = vop3->abs[0];
abs[1] = vop3->abs[1];
/* we cannot use these modifiers between mul and add */
if (vop3->clamp || vop3->omod)
return;
}
/* convert to mad */
ctx.uses[mul_instr->definitions[0].tempId()]--;
if (ctx.uses[mul_instr->definitions[0].tempId()]) {
if (op[0].isTemp())
ctx.uses[op[0].tempId()]++;
if (op[1].isTemp())
ctx.uses[op[1].tempId()]++;
}
if (instr->isVOP3()) {
@ -2888,8 +2877,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
(ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16 : aco_opcode::v_mad_f16);
aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(mad_op, Format::VOP3A, 3, 1)};
for (unsigned i = 0; i < 3; i++)
{
for (unsigned i = 0; i < 3; i++) {
mad->operands[i] = op[i];
mad->neg[i] = neg[i];
mad->abs[i] = abs[i];