aco/optimizer: add new helper functions for combining two instructions

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38150>
Georg Lehmann 2024-11-26 16:08:01 +01:00 committed by Marge Bot
parent 87e168f223
commit 1e2aea7461


@@ -3065,6 +3065,347 @@ is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value
return false;
}
/* This function attempts to propagate (potential) input modifiers from the consuming
* instruction backwards to the producing instruction.
* Because in-between swizzles are resolved, it also changes num_components of the
* producer's operands to match the consumer.
*
* - info is the instruction info of the producing instruction
* - op_info is the Operand info of the consuming instruction
* - type is the aco type of op_info
*/
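/* Illustrative example (hypothetical operands, not taken from this change):
* if the consumer reads -|%x| and %x is produced by v_mul_f32 %x, %a, %b,
* the abs/neg can be folded backwards into the producer as
* v_mul_f32 %x, -|%a|, |%b|, so the consumer can read %x without modifiers.
*/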
bool
backpropagate_input_modifiers(opt_ctx& ctx, alu_opt_info& info, const alu_opt_op& op_info,
const aco_type& type)
{
if (op_info.f16_to_f32 || op_info.dpp16 || op_info.dpp8)
return false;
aco_type dest_type = instr_info.alu_opcode_infos[(int)info.opcode].def_types[0];
if (info.f32_to_f16)
dest_type.bit_size = 16;
if (info.uses_insert())
return false;
/* Resolve swizzles first. */
if (op_info.op.size() > 1) {
assert(type.num_components == 1);
} else {
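/* Gather the swizzle implied by the consumer's per-component extract info. */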
bitarray8 swizzle = 0;
for (unsigned comp = 0; comp < type.num_components; comp++) {
/* Check if this extract is a swizzle or some other subdword access. */
if (op_info.extract[comp].offset() * 8 % type.bit_size != 0 ||
op_info.extract[comp].size() * 8 < type.bit_size)
return false;
swizzle[comp] = op_info.extract[comp].offset() * 8 / type.bit_size;
}
if (swizzle != 0 && dest_type.num_components == 1)
return false;
if (swizzle == 0b10) {
/* Identity swizzle: nothing to do. */
} else if (info.opcode == aco_opcode::v_cvt_pkrtz_f16_f32 ||
info.opcode == aco_opcode::v_cvt_pkrtz_f16_f32_e64 ||
info.opcode == aco_opcode::s_cvt_pk_rtz_f16_f32 ||
info.opcode == aco_opcode::v_pack_b32_f16) {
if (swizzle == 0b01) {
std::swap(info.operands[0], info.operands[1]);
} else {
unsigned broadcast = swizzle == 0b00 ? 0 : 1;
info.operands[!broadcast] = info.operands[broadcast];
}
} else {
for (alu_opt_op& op : info.operands) {
if (swizzle == 0b01) {
op.neg[0].swap(op.neg[1]);
op.abs[0].swap(op.abs[1]);
std::swap(op.extract[0], op.extract[1]);
} else {
unsigned broadcast = swizzle == 0b00 ? 0 : 1;
op.neg[!broadcast] = op.neg[broadcast];
op.abs[!broadcast] = op.abs[broadcast];
op.extract[!broadcast] = op.extract[broadcast];
}
}
}
}
if (!op_info.abs && !op_info.neg)
return true;
if (info.clamp || type.bit_size != dest_type.bit_size)
return false;
/* neg(omod(...)) and omod(neg(...)) are not the same because omod turns -0.0 into +0.0.
* Adds and DX9 mul have similar limitations.
*/
bool require_neg_nsz = info.omod;
/* Apply modifiers for each component. */
switch (info.opcode) {
case aco_opcode::v_mul_legacy_f32: require_neg_nsz = true; FALLTHROUGH;
case aco_opcode::v_mul_f64_e64:
case aco_opcode::v_mul_f64:
case aco_opcode::v_mul_f32:
case aco_opcode::v_mul_f16:
case aco_opcode::s_mul_f32:
case aco_opcode::s_mul_f16:
case aco_opcode::v_pk_mul_f16:
case aco_opcode::v_cvt_f32_f64:
case aco_opcode::v_cvt_f64_f32:
case aco_opcode::v_cvt_f16_f32:
case aco_opcode::v_cvt_f32_f16:
case aco_opcode::s_cvt_f16_f32:
case aco_opcode::s_cvt_f32_f16:
case aco_opcode::p_v_cvt_f16_f32_rtne:
case aco_opcode::p_s_cvt_f16_f32_rtne:
for (alu_opt_op& op : info.operands) {
op.neg &= ~op_info.abs;
op.abs |= op_info.abs;
}
info.operands[0].neg ^= op_info.neg;
break;
case aco_opcode::v_cndmask_b32:
case aco_opcode::v_cndmask_b16:
case aco_opcode::s_cselect_b32:
case aco_opcode::s_cselect_b64:
for (unsigned i = 0; i < 2; i++) {
info.operands[i].neg &= ~op_info.abs;
info.operands[i].abs |= op_info.abs;
info.operands[i].neg ^= op_info.neg;
}
break;
case aco_opcode::v_add_f64_e64:
case aco_opcode::v_add_f64:
case aco_opcode::v_add_f32:
case aco_opcode::v_add_f16:
case aco_opcode::s_add_f32:
case aco_opcode::s_add_f16:
case aco_opcode::v_pk_add_f16:
case aco_opcode::v_fma_f64:
case aco_opcode::v_fma_f32:
case aco_opcode::v_fma_f16:
case aco_opcode::s_fmac_f32:
case aco_opcode::s_fmac_f16:
case aco_opcode::v_pk_fma_f16:
case aco_opcode::v_fma_legacy_f32:
case aco_opcode::v_fma_legacy_f16:
case aco_opcode::v_mad_f32:
case aco_opcode::v_mad_f16:
case aco_opcode::v_mad_legacy_f32:
case aco_opcode::v_mad_legacy_f16:
if (op_info.abs)
return false;
info.operands[0].neg ^= op_info.neg;
info.operands.back().neg ^= op_info.neg;
require_neg_nsz = true;
break;
case aco_opcode::v_min_f64_e64:
case aco_opcode::v_min_f64:
case aco_opcode::v_min_f32:
case aco_opcode::v_min_f16:
case aco_opcode::v_max_f64_e64:
case aco_opcode::v_max_f64:
case aco_opcode::v_max_f32:
case aco_opcode::v_max_f16:
case aco_opcode::v_min3_f32:
case aco_opcode::v_min3_f16:
case aco_opcode::v_max3_f32:
case aco_opcode::v_max3_f16:
case aco_opcode::v_minmax_f32:
case aco_opcode::v_minmax_f16:
case aco_opcode::v_maxmin_f32:
case aco_opcode::v_maxmin_f16:
case aco_opcode::s_min_f32:
case aco_opcode::s_min_f16:
case aco_opcode::s_max_f32:
case aco_opcode::s_max_f16:
case aco_opcode::v_pk_min_f16:
case aco_opcode::v_pk_max_f16:
if (op_info.abs)
return false;
if (op_info.neg[0] != op_info.neg[type.num_components - 1])
return false;
for (alu_opt_op& op : info.operands)
op.neg ^= op_info.neg;
switch (info.opcode) {
case aco_opcode::v_min_f64_e64: info.opcode = aco_opcode::v_max_f64_e64; break;
case aco_opcode::v_min_f64: info.opcode = aco_opcode::v_max_f64; break;
case aco_opcode::v_min_f32: info.opcode = aco_opcode::v_max_f32; break;
case aco_opcode::v_min_f16: info.opcode = aco_opcode::v_max_f16; break;
case aco_opcode::v_max_f64_e64: info.opcode = aco_opcode::v_min_f64_e64; break;
case aco_opcode::v_max_f64: info.opcode = aco_opcode::v_min_f64; break;
case aco_opcode::v_max_f32: info.opcode = aco_opcode::v_min_f32; break;
case aco_opcode::v_max_f16: info.opcode = aco_opcode::v_min_f16; break;
case aco_opcode::v_min3_f32: info.opcode = aco_opcode::v_max3_f32; break;
case aco_opcode::v_min3_f16: info.opcode = aco_opcode::v_max3_f16; break;
case aco_opcode::v_max3_f32: info.opcode = aco_opcode::v_min3_f32; break;
case aco_opcode::v_max3_f16: info.opcode = aco_opcode::v_min3_f16; break;
case aco_opcode::v_minmax_f32: info.opcode = aco_opcode::v_maxmin_f32; break;
case aco_opcode::v_minmax_f16: info.opcode = aco_opcode::v_maxmin_f16; break;
case aco_opcode::v_maxmin_f32: info.opcode = aco_opcode::v_minmax_f32; break;
case aco_opcode::v_maxmin_f16: info.opcode = aco_opcode::v_minmax_f16; break;
case aco_opcode::s_min_f32: info.opcode = aco_opcode::s_max_f32; break;
case aco_opcode::s_min_f16: info.opcode = aco_opcode::s_max_f16; break;
case aco_opcode::s_max_f32: info.opcode = aco_opcode::s_min_f32; break;
case aco_opcode::s_max_f16: info.opcode = aco_opcode::s_min_f16; break;
case aco_opcode::v_pk_min_f16: info.opcode = aco_opcode::v_pk_max_f16; break;
case aco_opcode::v_pk_max_f16: info.opcode = aco_opcode::v_pk_min_f16; break;
default: UNREACHABLE("invalid op");
}
break;
case aco_opcode::v_cvt_pkrtz_f16_f32:
case aco_opcode::v_cvt_pkrtz_f16_f32_e64:
case aco_opcode::s_cvt_pk_rtz_f16_f32:
case aco_opcode::v_pack_b32_f16:
for (unsigned comp = 0; comp < type.num_components; comp++) {
if (op_info.abs[comp]) {
info.operands[comp].neg[0] = false;
info.operands[comp].abs[0] = true;
}
info.operands[comp].neg[0] ^= op_info.neg[comp];
}
break;
default: return false;
}
if (op_info.neg && require_neg_nsz && info.defs[0].isSZPreserve())
return false;
return true;
}
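/* Optional per-pattern hook that can further adjust or validate the combined
* instruction; returning false rejects the match.
*/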
typedef bool (*combine_instr_callback)(opt_ctx& ctx, alu_opt_info& info);
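/* Describes how a producer of src_opcode feeding one of the operands selected by
* operand_mask can be fused with the consumer into res_opcode. Each digit of the
* swizzle string selects an operand of the combined instruction: digits smaller than
* the number of remaining consumer operands pick one of the consumer's operands
* (with the matched operand removed), higher digits pick one of the producer's
* operands.
*
* Illustrative example (hypothetical, not part of this change): fusing v_mul_f32
* into v_add_f32 to form v_fma_f32 could be written as
* {aco_opcode::v_mul_f32, aco_opcode::v_fma_f32, 0x3, "120", nullptr}:
* result operands 0 and 1 come from the mul, operand 2 is the add's remaining
* operand.
*/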
struct combine_instr_pattern {
aco_opcode src_opcode;
aco_opcode res_opcode;
unsigned operand_mask;
const char* swizzle;
combine_instr_callback callback;
};
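/* Check whether the instruction producing op is a candidate for combining: its first
* definition must be op, and it must either have been emitted with the same exec mask
* (pass_flags) or not be DPP/VINTERP_INREG and not read exec.
*/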
bool
can_match_op(opt_ctx& ctx, Operand op, uint32_t exec_id)
{
if (!op.isTemp())
return false;
Instruction* op_instr = ctx.info[op.tempId()].parent_instr;
if (op_instr->definitions[0].getTemp() != op.getTemp())
return false;
if (op_instr->pass_flags == exec_id)
return true;
if (op_instr->isDPP() || op_instr->isVINTERP_INREG() || op_instr->reads_exec())
return false;
return true;
}
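/* Try to fuse the instruction producing one of info's operands into info using the
* given patterns. On success, info is replaced with the combined instruction and
* true is returned; otherwise info is left unchanged.
*/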
bool
match_and_apply_patterns(opt_ctx& ctx, alu_opt_info& info,
const aco::small_vec<combine_instr_pattern, 8>& patterns)
{
if (patterns.empty())
return false;
unsigned total_mask = 0;
for (const combine_instr_pattern& pattern : patterns)
total_mask |= pattern.operand_mask;
for (unsigned i = 0; i < info.operands.size(); i++) {
if (!can_match_op(ctx, info.operands[i].op, info.pass_flags))
total_mask &= ~BITFIELD_BIT(i);
}
if (!total_mask)
return false;
aco::small_vec<int, 4> indices;
indices.reserve(util_bitcount(total_mask));
u_foreach_bit (i, total_mask)
indices.push_back(i);
std::stable_sort(indices.begin(), indices.end(),
[&](int a, int b)
{
Temp temp_a = info.operands[a].op.getTemp();
Temp temp_b = info.operands[b].op.getTemp();
/* Fewer uses make it more likely/profitable to eliminate an instruction. */
if (ctx.uses[temp_a.id()] != ctx.uses[temp_b.id()])
return ctx.uses[temp_a.id()] < ctx.uses[temp_b.id()];
/* Prefer eliminating VALU instructions. */
if (temp_a.type() != temp_b.type())
return temp_a.type() == RegType::vgpr;
/* The id is a good approximation of instruction order; prefer
* instructions closer to info so that register pressure does not
* increase as much.
*/
return temp_a.id() > temp_b.id();
});
for (unsigned op_idx : indices) {
Temp tmp = info.operands[op_idx].op.getTemp();
alu_opt_info op_instr;
if (!alu_opt_gather_info(ctx, ctx.info[tmp.id()].parent_instr, op_instr))
continue;
if (op_instr.clamp || op_instr.omod || op_instr.f32_to_f16)
continue;
aco_type type = instr_info.alu_opcode_infos[(int)info.opcode].op_types[op_idx];
if (!backpropagate_input_modifiers(ctx, op_instr, info.operands[op_idx], type))
continue;
for (const combine_instr_pattern& pattern : patterns) {
if (!(pattern.operand_mask & BITFIELD_BIT(op_idx)) ||
op_instr.opcode != pattern.src_opcode)
continue;
alu_opt_info new_info = info;
unsigned rem = info.operands.size() - 1;
unsigned op_count = rem + op_instr.operands.size();
new_info.operands.resize(op_count);
assert(strlen(pattern.swizzle) == op_count);
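/* Map each result operand: swizzle digits below rem select one of the consumer's
* remaining operands, digits >= rem select one of the producer's operands. */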
for (unsigned i = 0; i < op_count; i++) {
unsigned src_idx = pattern.swizzle[i] - '0';
if (src_idx < op_idx)
new_info.operands[i] = info.operands[src_idx];
else if (src_idx < rem)
new_info.operands[i] = info.operands[src_idx + 1];
else
new_info.operands[i] = op_instr.operands[src_idx - rem];
}
new_info.opcode = pattern.res_opcode;
if (op_instr.defs[0].isPrecise())
new_info.defs[0].setPrecise(true);
if (pattern.callback && !pattern.callback(ctx, new_info))
continue;
if (alu_opt_info_is_valid(ctx, new_info)) {
info = std::move(new_info);
return true;
}
}
}
return false;
}
/* s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
bool
combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)