From 01134b0bfe407f43d8089551301ffedaeeb459ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Sch=C3=BCrmann?= Date: Wed, 2 Sep 2020 15:19:21 +0100 Subject: [PATCH] aco: simplify multiply-add combining When both operands of a v_sub (same applies for v_add) are mul and one already uses clamp/omod, pick the other operand to get a chance to combine to a MAD. No fossils-db changes. Co-authored-by: Samuel Pitoiset Reviewed-by: Rhys Perry Part-of: --- src/amd/compiler/aco_optimizer.cpp | 84 +++++++++++++----------------- 1 file changed, 36 insertions(+), 48 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index c986c6f69e7..967656c3c72 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -2797,49 +2797,50 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr if (need_fma && mad32 && !ctx.program->has_fast_fma32) return; - uint32_t uses_src0 = UINT32_MAX; - uint32_t uses_src1 = UINT32_MAX; Instruction* mul_instr = nullptr; unsigned add_op_idx; - /* check if any of the operands is a multiplication */ - ssa_info *op0_info = instr->operands[0].isTemp() ? &ctx.info[instr->operands[0].tempId()] : NULL; - ssa_info *op1_info = instr->operands[1].isTemp() ?
&ctx.info[instr->operands[1].tempId()] : NULL; - if (op0_info && op0_info->is_mul() && (!need_fma || !op0_info->instr->definitions[0].isPrecise())) - uses_src0 = ctx.uses[instr->operands[0].tempId()]; - if (op1_info && op1_info->is_mul() && (!need_fma || !op1_info->instr->definitions[0].isPrecise())) - uses_src1 = ctx.uses[instr->operands[1].tempId()]; - + uint32_t uses = UINT32_MAX; /* find the 'best' mul instruction to combine with the add */ - if (uses_src0 < uses_src1) { - mul_instr = op0_info->instr; - add_op_idx = 1; - } else if (uses_src1 < uses_src0) { - mul_instr = op1_info->instr; - add_op_idx = 0; - } else if (uses_src0 != UINT32_MAX) { - /* tiebreaker: quite random what to pick */ - if (op0_info->instr->operands[0].isLiteral()) { - mul_instr = op1_info->instr; - add_op_idx = 0; - } else { - mul_instr = op0_info->instr; - add_op_idx = 1; - } + for (unsigned i = 0; i < 2; i++) { + if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul()) + continue; + /* check precision requirements */ + ssa_info& info = ctx.info[instr->operands[i].tempId()]; + if (need_fma && info.instr->definitions[0].isPrecise()) + continue; + + /* no clamp/omod allowed between mul and add */ + if (info.instr->isVOP3() && + (static_cast(info.instr)->clamp || + static_cast(info.instr)->omod)) + continue; + + Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]}; + if (info.instr->isSDWA() || + !check_vop3_operands(ctx, 3, op) || + ctx.uses[instr->operands[i].tempId()] >= uses) + continue; + + mul_instr = info.instr; + add_op_idx = 1 - i; + uses = ctx.uses[instr->operands[i].tempId()]; } + if (mul_instr) { - Operand op[3] = {Operand(v1), Operand(v1), Operand(v1)}; + /* turn mul+add into v_mad/v_fma */ + Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1], instr->operands[add_op_idx]}; + ctx.uses[mul_instr->definitions[0].tempId()]--; + if (ctx.uses[mul_instr->definitions[0].tempId()]) { + if (op[0].isTemp()) + 
ctx.uses[op[0].tempId()]++; + if (op[1].isTemp()) + ctx.uses[op[1].tempId()]++; + } + bool neg[3] = {false, false, false}; bool abs[3] = {false, false, false}; unsigned omod = 0; bool clamp = false; - op[0] = mul_instr->operands[0]; - op[1] = mul_instr->operands[1]; - op[2] = instr->operands[add_op_idx]; - // TODO: would be better to check this before selecting a mul instr? - if (!check_vop3_operands(ctx, 3, op)) - return; - if (mul_instr->isSDWA()) - return; if (mul_instr->isVOP3()) { VOP3A_instruction* vop3 = static_cast (mul_instr); @@ -2847,18 +2848,6 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr neg[1] = vop3->neg[1]; abs[0] = vop3->abs[0]; abs[1] = vop3->abs[1]; - /* we cannot use these modifiers between mul and add */ - if (vop3->clamp || vop3->omod) - return; - } - - /* convert to mad */ - ctx.uses[mul_instr->definitions[0].tempId()]--; - if (ctx.uses[mul_instr->definitions[0].tempId()]) { - if (op[0].isTemp()) - ctx.uses[op[0].tempId()]++; - if (op[1].isTemp()) - ctx.uses[op[1].tempId()]++; } if (instr->isVOP3()) { @@ -2888,8 +2877,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16 : aco_opcode::v_mad_f16); aco_ptr mad{create_instruction(mad_op, Format::VOP3A, 3, 1)}; - for (unsigned i = 0; i < 3; i++) - { + for (unsigned i = 0; i < 3; i++) { mad->operands[i] = op[i]; mad->neg[i] = neg[i]; mad->abs[i] = abs[i];