aco/optimizer: use new helpers to apply neg/abs to output of instructions

Foz-DB Navi21:
Totals from 6765 (6.93% of 97591) affected shaders:
MaxWaves: 134398 -> 134408 (+0.01%)
Instrs: 9775725 -> 9768079 (-0.08%); split: -0.08%, +0.01%
CodeSize: 50785228 -> 50777880 (-0.01%); split: -0.02%, +0.01%
VGPRs: 445840 -> 445784 (-0.01%)
SpillSGPRs: 14483 -> 14476 (-0.05%)
Latency: 40232431 -> 40230284 (-0.01%); split: -0.04%, +0.03%
InvThroughput: 10339051 -> 10329846 (-0.09%); split: -0.09%, +0.00%
VClause: 186785 -> 186788 (+0.00%); split: -0.01%, +0.01%
SClause: 157106 -> 157116 (+0.01%); split: -0.00%, +0.01%
Copies: 746817 -> 745378 (-0.19%); split: -0.26%, +0.07%
Branches: 189298 -> 189211 (-0.05%); split: -0.06%, +0.01%
PreSGPRs: 346169 -> 346158 (-0.00%)
PreVGPRs: 370712 -> 370660 (-0.01%); split: -0.02%, +0.00%
VALU: 6847295 -> 6839753 (-0.11%); split: -0.11%, +0.00%
SALU: 1139960 -> 1139942 (-0.00%); split: -0.00%, +0.00%

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38658>
This commit is contained in:
Georg Lehmann 2025-01-09 19:48:47 +01:00 committed by Marge Bot
parent 58f407702d
commit 4442064449

View file

@ -3921,6 +3921,66 @@ op_info_get_constant(opt_ctx& ctx, alu_opt_op op_info, aco_type type, uint64_t*
return true;
}
/* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b))
 *
 * instr is a float multiply with one constant operand. When every component
 * of that constant is +1.0 or -1.0, the input modifiers of the other operand
 * (with neg toggled for each -1.0 component) are handed to
 * backpropagate_input_modifiers() for parent, and instr's clamp/omod/insert
 * and definition temp are transferred onto parent's info.
 *
 * Returns the result of alu_opt_info_to_instr() on success, or nullptr as
 * soon as any check fails.
 */
Instruction*
apply_output_mul(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
{
   alu_opt_info mul_info;
   if (!alu_opt_gather_info(ctx, instr.get(), mul_info))
      return nullptr;

   aco_type type = instr_info.alu_opcode_infos[(int)instr->opcode].def_types[0];

   /* Bail out if parent's result is not marked canonicalized for this bit
    * size, unless the active denorm mode is fp_denorm_keep. */
   unsigned denorm_mode;
   if (type.bit_size == 32)
      denorm_mode = ctx.fp_mode.denorm32;
   else
      denorm_mode = ctx.fp_mode.denorm16_64;
   if (!ctx.info[parent->definitions[0].tempId()].is_canonicalized(type.bit_size) &&
       denorm_mode != fp_denorm_keep)
      return nullptr;

   /* Both results must agree in component count, bit size and register type. */
   const auto& parent_op_info = instr_info.alu_opcode_infos[(int)parent->opcode];
   aco_type parent_type = parent_op_info.def_types[0];
   if (type.num_components != parent_type.num_components)
      return nullptr;
   if (type.bit_size != parent_type.bit_size)
      return nullptr;
   if (instr->definitions[0].regClass().type() != parent->definitions[0].regClass().type())
      return nullptr;

   /* Operand 0 is taken as the constant if it is one, otherwise operand 1;
    * the remaining operand is the one whose modifiers get propagated. */
   const unsigned const_idx = mul_info.operands[0].op.isConstant() ? 0u : 1u;
   const unsigned src_idx = const_idx ^ 1u;

   uint64_t const_bits = 0;
   if (!op_info_get_constant(ctx, mul_info.operands[const_idx], type, &const_bits))
      return nullptr;

   /* Every component of the constant has to be exactly +1.0 or -1.0; a
    * negative component is folded into the source operand's neg modifier. */
   for (unsigned comp = 0; comp < type.num_components; comp++) {
      const double value = extract_float(const_bits, type.bit_size, comp);
      const bool negative = value < 0.0;
      if (negative)
         mul_info.operands[src_idx].neg[comp] ^= true;
      if ((negative ? fabs(value) : value) != 1.0)
         return nullptr;
   }

   /* clamp/omod on the multiply can only move to a parent opcode that
    * advertises output modifiers. */
   const bool wants_output_mods = mul_info.omod || mul_info.clamp;
   if (wants_output_mods && !parent_op_info.output_modifiers)
      return nullptr;

   alu_opt_info parent_info;
   if (!alu_opt_gather_info(ctx, parent, parent_info))
      return nullptr;
   if (parent_info.uses_insert())
      return nullptr;
   /* An omod on the multiply cannot be combined with an omod or clamp
    * already present on parent. */
   if (mul_info.omod && (parent_info.omod || parent_info.clamp))
      return nullptr;

   if (!backpropagate_input_modifiers(ctx, parent_info, mul_info.operands[src_idx], type))
      return nullptr;

   /* parent takes over the multiply's output modifiers and its result temp. */
   parent_info.clamp |= mul_info.clamp;
   parent_info.omod |= mul_info.omod;
   parent_info.insert = mul_info.insert;
   parent_info.defs[0].setTemp(mul_info.defs[0].getTemp());

   if (alu_opt_info_is_valid(ctx, parent_info))
      return alu_opt_info_to_instr(ctx, parent_info, parent);
   return nullptr;
}
Instruction*
apply_output_impl(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent)
{
@ -3935,6 +3995,9 @@ apply_output_impl(opt_ctx& ctx, aco_ptr<Instruction>& instr, Instruction* parent
return apply_s_not(ctx, instr, parent);
else if (instr->opcode == aco_opcode::s_abs_i32)
return apply_s_abs(ctx, instr, parent);
else if (instr->opcode == aco_opcode::v_mul_f64 || instr->opcode == aco_opcode::v_mul_f64_e64 ||
instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16)
return apply_output_mul(ctx, instr, parent);
else
UNREACHABLE("unhandled opcode");
@ -3949,7 +4012,11 @@ apply_output(opt_ctx& ctx, aco_ptr<Instruction>& instr)
case aco_opcode::v_not_b32:
case aco_opcode::s_not_b32:
case aco_opcode::s_not_b64:
case aco_opcode::s_abs_i32: break;
case aco_opcode::s_abs_i32:
case aco_opcode::v_mul_f64:
case aco_opcode::v_mul_f64_e64:
case aco_opcode::v_mul_f32:
case aco_opcode::v_mul_f16: break;
default: return false;
}
@ -4195,21 +4262,6 @@ and_cb(opt_ctx& ctx, alu_opt_info& info)
return func1(ctx, info) && func2(ctx, info);
}
/* Returns true when instr is one of the float multiply opcodes
 * (v_mul_f64_e64, v_mul_f64, v_mul_f32, v_mul_legacy_f32, v_mul_f16).
 * v_fma_mix_f32 only qualifies when operand 2 is the constant 0 with its
 * neg modifier set. */
bool
is_mul(Instruction* instr)
{
   const auto op = instr->opcode;

   /* The fma_mix case depends on its third operand, so handle it first. */
   if (op == aco_opcode::v_fma_mix_f32)
      return instr->operands[2].constantEquals(0) && instr->valu().neg[2];

   return op == aco_opcode::v_mul_f64_e64 || op == aco_opcode::v_mul_f64 ||
          op == aco_opcode::v_mul_f32 || op == aco_opcode::v_mul_legacy_f32 ||
          op == aco_opcode::v_mul_f16;
}
void
combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
@ -4258,57 +4310,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
* The various comparison optimizations also currently only work with 32-bit
* floats. */
/* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
if ((ctx.info[instr->definitions[0].tempId()].label & input_mod_labels) &&
ctx.uses[ctx.info[instr->definitions[0].tempId()].temp.id()] == 1) {
Temp val = ctx.info[instr->definitions[0].tempId()].temp;
Instruction* mul_instr = ctx.info[val.id()].parent_instr;
if (!is_mul(mul_instr))
return;
if (mul_instr->operands[0].isLiteral())
return;
if (mul_instr->valu().clamp)
return;
if (mul_instr->isSDWA() || mul_instr->isDPP())
return;
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
mul_instr->definitions[0].isSZPreserve())
return;
if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
return;
/* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
ctx.uses[mul_instr->definitions[0].tempId()]--;
Definition def = instr->definitions[0];
bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg(def.bytes() * 8);
bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs(def.bytes() * 8);
uint32_t pass_flags = instr->pass_flags;
Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
instr.reset(create_instruction(mul_instr->opcode, format, mul_instr->operands.size(), 1));
std::copy(mul_instr->operands.cbegin(), mul_instr->operands.cend(), instr->operands.begin());
instr->pass_flags = pass_flags;
instr->definitions[0] = def;
VALU_instruction& new_mul = instr->valu();
VALU_instruction& mul = mul_instr->valu();
new_mul.neg = mul.neg;
new_mul.abs = mul.abs;
new_mul.omod = mul.omod;
new_mul.opsel = mul.opsel;
new_mul.opsel_lo = mul.opsel_lo;
new_mul.opsel_hi = mul.opsel_hi;
if (is_abs) {
new_mul.neg[0] = new_mul.neg[1] = false;
new_mul.abs[0] = new_mul.abs[1] = true;
}
new_mul.neg[0] ^= is_neg;
new_mul.clamp = false;
ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get();
return;
}
alu_opt_info info;
if (!alu_opt_gather_info(ctx, instr.get(), info))
return;