mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-24 17:30:12 +01:00
aco/optimizer: add new helper functions for combining two instructions
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38150>
This commit is contained in:
parent
87e168f223
commit
1e2aea7461
1 changed files with 341 additions and 0 deletions
|
|
@ -3065,6 +3065,347 @@ is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value
|
|||
return false;
|
||||
}
|
||||
|
||||
/* This function attempts to propagate (potential) input modifers from the consuming
|
||||
* instruction backwards to the producing instruction.
|
||||
* Because inbetween swizzles are resolved,
|
||||
* it also changes num_components of the producer's operands to match consumer.
|
||||
*
|
||||
* - info is the instruction info of the producing instruction
|
||||
* - op_info is the Operand info of the consuming instruction
|
||||
* - type is the aco type of op_info
|
||||
*/
|
||||
bool
|
||||
backpropagate_input_modifiers(opt_ctx& ctx, alu_opt_info& info, const alu_opt_op& op_info,
|
||||
const aco_type& type)
|
||||
{
|
||||
if (op_info.f16_to_f32 || op_info.dpp16 || op_info.dpp8)
|
||||
return false;
|
||||
|
||||
aco_type dest_type = instr_info.alu_opcode_infos[(int)info.opcode].def_types[0];
|
||||
|
||||
if (info.f32_to_f16)
|
||||
dest_type.bit_size = 16;
|
||||
|
||||
if (info.uses_insert())
|
||||
return false;
|
||||
|
||||
/* Resolve swizzles first. */
|
||||
if (op_info.op.size() > 1) {
|
||||
assert(type.num_components == 1);
|
||||
} else {
|
||||
bitarray8 swizzle = 0;
|
||||
for (unsigned comp = 0; comp < type.num_components; comp++) {
|
||||
/* Check if this extract is a swizzle or some other subdword access. */
|
||||
if (op_info.extract[comp].offset() * 8 % type.bit_size != 0 ||
|
||||
op_info.extract[comp].size() * 8 < type.bit_size)
|
||||
return false;
|
||||
swizzle[comp] = op_info.extract[comp].offset() * 8 / type.bit_size;
|
||||
}
|
||||
|
||||
if (swizzle != 0 && dest_type.num_components == 1)
|
||||
return false;
|
||||
|
||||
if (swizzle == 0b10) {
|
||||
/* noop */
|
||||
} else if (info.opcode == aco_opcode::v_cvt_pkrtz_f16_f32 ||
|
||||
info.opcode == aco_opcode::v_cvt_pkrtz_f16_f32_e64 ||
|
||||
info.opcode == aco_opcode::s_cvt_pk_rtz_f16_f32 ||
|
||||
info.opcode == aco_opcode::v_pack_b32_f16) {
|
||||
if (swizzle == 0b01) {
|
||||
std::swap(info.operands[0], info.operands[1]);
|
||||
} else {
|
||||
unsigned broadcast = swizzle == 0b00 ? 0 : 1;
|
||||
info.operands[!broadcast] = info.operands[broadcast];
|
||||
}
|
||||
} else {
|
||||
for (alu_opt_op& op : info.operands) {
|
||||
if (swizzle == 0b01) {
|
||||
op.neg[0].swap(op.neg[1]);
|
||||
op.abs[0].swap(op.abs[1]);
|
||||
std::swap(op.extract[0], op.extract[1]);
|
||||
} else {
|
||||
unsigned broadcast = swizzle == 0b00 ? 0 : 1;
|
||||
op.neg[!broadcast] = op.neg[broadcast];
|
||||
op.abs[!broadcast] = op.abs[broadcast];
|
||||
op.extract[!broadcast] = op.extract[broadcast];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!op_info.abs && !op_info.neg)
|
||||
return true;
|
||||
|
||||
if (info.clamp || type.bit_size != dest_type.bit_size)
|
||||
return false;
|
||||
|
||||
/* neg(omod(...)) and omod(neg(...)) are not the same because omod turn -0.0 into +0.0.
|
||||
* Adds and dx9 mul have similar limitations.
|
||||
*/
|
||||
bool require_neg_nsz = info.omod;
|
||||
|
||||
/* Apply modifiers for each component. */
|
||||
switch (info.opcode) {
|
||||
case aco_opcode::v_mul_legacy_f32: require_neg_nsz = true; FALLTHROUGH;
|
||||
case aco_opcode::v_mul_f64_e64:
|
||||
case aco_opcode::v_mul_f64:
|
||||
case aco_opcode::v_mul_f32:
|
||||
case aco_opcode::v_mul_f16:
|
||||
case aco_opcode::s_mul_f32:
|
||||
case aco_opcode::s_mul_f16:
|
||||
case aco_opcode::v_pk_mul_f16:
|
||||
case aco_opcode::v_cvt_f32_f64:
|
||||
case aco_opcode::v_cvt_f64_f32:
|
||||
case aco_opcode::v_cvt_f16_f32:
|
||||
case aco_opcode::v_cvt_f32_f16:
|
||||
case aco_opcode::s_cvt_f16_f32:
|
||||
case aco_opcode::s_cvt_f32_f16:
|
||||
case aco_opcode::p_v_cvt_f16_f32_rtne:
|
||||
case aco_opcode::p_s_cvt_f16_f32_rtne:
|
||||
for (alu_opt_op& op : info.operands) {
|
||||
op.neg &= ~op_info.abs;
|
||||
op.abs |= op_info.abs;
|
||||
}
|
||||
info.operands[0].neg ^= op_info.neg;
|
||||
break;
|
||||
case aco_opcode::v_cndmask_b32:
|
||||
case aco_opcode::v_cndmask_b16:
|
||||
case aco_opcode::s_cselect_b32:
|
||||
case aco_opcode::s_cselect_b64:
|
||||
for (unsigned i = 0; i < 2; i++) {
|
||||
info.operands[i].neg &= ~op_info.abs;
|
||||
info.operands[i].abs |= op_info.abs;
|
||||
info.operands[i].neg ^= op_info.neg;
|
||||
}
|
||||
break;
|
||||
case aco_opcode::v_add_f64_e64:
|
||||
case aco_opcode::v_add_f64:
|
||||
case aco_opcode::v_add_f32:
|
||||
case aco_opcode::v_add_f16:
|
||||
case aco_opcode::s_add_f32:
|
||||
case aco_opcode::s_add_f16:
|
||||
case aco_opcode::v_pk_add_f16:
|
||||
case aco_opcode::v_fma_f64:
|
||||
case aco_opcode::v_fma_f32:
|
||||
case aco_opcode::v_fma_f16:
|
||||
case aco_opcode::s_fmac_f32:
|
||||
case aco_opcode::s_fmac_f16:
|
||||
case aco_opcode::v_pk_fma_f16:
|
||||
case aco_opcode::v_fma_legacy_f32:
|
||||
case aco_opcode::v_fma_legacy_f16:
|
||||
case aco_opcode::v_mad_f32:
|
||||
case aco_opcode::v_mad_f16:
|
||||
case aco_opcode::v_mad_legacy_f32:
|
||||
case aco_opcode::v_mad_legacy_f16:
|
||||
if (op_info.abs)
|
||||
return false;
|
||||
info.operands[0].neg ^= op_info.neg;
|
||||
info.operands.back().neg ^= op_info.neg;
|
||||
require_neg_nsz = true;
|
||||
break;
|
||||
case aco_opcode::v_min_f64_e64:
|
||||
case aco_opcode::v_min_f64:
|
||||
case aco_opcode::v_min_f32:
|
||||
case aco_opcode::v_min_f16:
|
||||
case aco_opcode::v_max_f64_e64:
|
||||
case aco_opcode::v_max_f64:
|
||||
case aco_opcode::v_max_f32:
|
||||
case aco_opcode::v_max_f16:
|
||||
case aco_opcode::v_min3_f32:
|
||||
case aco_opcode::v_min3_f16:
|
||||
case aco_opcode::v_max3_f32:
|
||||
case aco_opcode::v_max3_f16:
|
||||
case aco_opcode::v_minmax_f32:
|
||||
case aco_opcode::v_minmax_f16:
|
||||
case aco_opcode::v_maxmin_f32:
|
||||
case aco_opcode::v_maxmin_f16:
|
||||
case aco_opcode::s_min_f32:
|
||||
case aco_opcode::s_min_f16:
|
||||
case aco_opcode::s_max_f32:
|
||||
case aco_opcode::s_max_f16:
|
||||
case aco_opcode::v_pk_min_f16:
|
||||
case aco_opcode::v_pk_max_f16:
|
||||
if (op_info.abs)
|
||||
return false;
|
||||
|
||||
if (op_info.neg[0] != op_info.neg[type.num_components - 1])
|
||||
return false;
|
||||
for (alu_opt_op& op : info.operands)
|
||||
op.neg ^= op_info.neg;
|
||||
|
||||
switch (info.opcode) {
|
||||
case aco_opcode::v_min_f64_e64: info.opcode = aco_opcode::v_max_f64_e64; break;
|
||||
case aco_opcode::v_min_f64: info.opcode = aco_opcode::v_max_f64; break;
|
||||
case aco_opcode::v_min_f32: info.opcode = aco_opcode::v_max_f32; break;
|
||||
case aco_opcode::v_min_f16: info.opcode = aco_opcode::v_max_f16; break;
|
||||
case aco_opcode::v_max_f64_e64: info.opcode = aco_opcode::v_min_f64_e64; break;
|
||||
case aco_opcode::v_max_f64: info.opcode = aco_opcode::v_min_f64; break;
|
||||
case aco_opcode::v_max_f32: info.opcode = aco_opcode::v_min_f32; break;
|
||||
case aco_opcode::v_max_f16: info.opcode = aco_opcode::v_min_f16; break;
|
||||
case aco_opcode::v_min3_f32: info.opcode = aco_opcode::v_max3_f32; break;
|
||||
case aco_opcode::v_min3_f16: info.opcode = aco_opcode::v_max3_f16; break;
|
||||
case aco_opcode::v_max3_f32: info.opcode = aco_opcode::v_min3_f32; break;
|
||||
case aco_opcode::v_max3_f16: info.opcode = aco_opcode::v_min3_f16; break;
|
||||
case aco_opcode::v_minmax_f32: info.opcode = aco_opcode::v_maxmin_f32; break;
|
||||
case aco_opcode::v_minmax_f16: info.opcode = aco_opcode::v_maxmin_f16; break;
|
||||
case aco_opcode::v_maxmin_f32: info.opcode = aco_opcode::v_minmax_f32; break;
|
||||
case aco_opcode::v_maxmin_f16: info.opcode = aco_opcode::v_minmax_f16; break;
|
||||
case aco_opcode::s_min_f32: info.opcode = aco_opcode::s_max_f32; break;
|
||||
case aco_opcode::s_min_f16: info.opcode = aco_opcode::s_max_f16; break;
|
||||
case aco_opcode::s_max_f32: info.opcode = aco_opcode::s_min_f32; break;
|
||||
case aco_opcode::s_max_f16: info.opcode = aco_opcode::s_min_f16; break;
|
||||
case aco_opcode::v_pk_min_f16: info.opcode = aco_opcode::v_pk_max_f16; break;
|
||||
case aco_opcode::v_pk_max_f16: info.opcode = aco_opcode::v_pk_min_f16; break;
|
||||
default: UNREACHABLE("invalid op");
|
||||
}
|
||||
break;
|
||||
case aco_opcode::v_cvt_pkrtz_f16_f32:
|
||||
case aco_opcode::v_cvt_pkrtz_f16_f32_e64:
|
||||
case aco_opcode::s_cvt_pk_rtz_f16_f32:
|
||||
case aco_opcode::v_pack_b32_f16:
|
||||
for (unsigned comp = 0; comp < type.num_components; comp++) {
|
||||
if (op_info.abs[comp]) {
|
||||
info.operands[comp].neg[0] = false;
|
||||
info.operands[comp].abs[0] = true;
|
||||
}
|
||||
info.operands[comp].neg[0] ^= op_info.neg[comp];
|
||||
}
|
||||
break;
|
||||
default: return false;
|
||||
}
|
||||
|
||||
if (op_info.neg && require_neg_nsz && info.defs[0].isSZPreserve())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
typedef bool (*combine_instr_callback)(opt_ctx& ctx, alu_opt_info& info);
|
||||
|
||||
struct combine_instr_pattern {
|
||||
aco_opcode src_opcode;
|
||||
aco_opcode res_opcode;
|
||||
unsigned operand_mask;
|
||||
const char* swizzle;
|
||||
combine_instr_callback callback;
|
||||
};
|
||||
|
||||
bool
|
||||
can_match_op(opt_ctx& ctx, Operand op, uint32_t exec_id)
|
||||
{
|
||||
if (!op.isTemp())
|
||||
return false;
|
||||
|
||||
Instruction* op_instr = ctx.info[op.tempId()].parent_instr;
|
||||
if (op_instr->definitions[0].getTemp() != op.getTemp())
|
||||
return false;
|
||||
|
||||
if (op_instr->pass_flags == exec_id)
|
||||
return true;
|
||||
|
||||
if (op_instr->isDPP() || op_instr->isVINTERP_INREG() || op_instr->reads_exec())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
match_and_apply_patterns(opt_ctx& ctx, alu_opt_info& info,
|
||||
const aco::small_vec<combine_instr_pattern, 8>& patterns)
|
||||
{
|
||||
if (patterns.empty())
|
||||
return false;
|
||||
|
||||
unsigned total_mask = 0;
|
||||
for (const combine_instr_pattern& pattern : patterns)
|
||||
total_mask |= pattern.operand_mask;
|
||||
|
||||
for (unsigned i = 0; i < info.operands.size(); i++) {
|
||||
if (!can_match_op(ctx, info.operands[i].op, info.pass_flags))
|
||||
total_mask &= ~BITFIELD_BIT(i);
|
||||
}
|
||||
|
||||
if (!total_mask)
|
||||
return false;
|
||||
|
||||
aco::small_vec<int, 4> indices;
|
||||
indices.reserve(util_bitcount(total_mask));
|
||||
u_foreach_bit (i, total_mask)
|
||||
indices.push_back(i);
|
||||
|
||||
std::stable_sort(indices.begin(), indices.end(),
|
||||
[&](int a, int b)
|
||||
{
|
||||
Temp temp_a = info.operands[a].op.getTemp();
|
||||
Temp temp_b = info.operands[b].op.getTemp();
|
||||
|
||||
/* Less uses make it more likely/profitable to eliminate an instruction. */
|
||||
if (ctx.uses[temp_a.id()] != ctx.uses[temp_b.id()])
|
||||
return ctx.uses[temp_a.id()] < ctx.uses[temp_b.id()];
|
||||
|
||||
/* Prefer eliminating VALU instructions. */
|
||||
if (temp_a.type() != temp_b.type())
|
||||
return temp_a.type() == RegType::vgpr;
|
||||
|
||||
/* The id is a good approximation for instruction order,
|
||||
* prefer instructions closer to info to not increase register pressure
|
||||
* as much.
|
||||
*/
|
||||
return temp_a.id() > temp_b.id();
|
||||
});
|
||||
|
||||
for (unsigned op_idx : indices) {
|
||||
Temp tmp = info.operands[op_idx].op.getTemp();
|
||||
alu_opt_info op_instr;
|
||||
if (!alu_opt_gather_info(ctx, ctx.info[tmp.id()].parent_instr, op_instr))
|
||||
continue;
|
||||
|
||||
if (op_instr.clamp || op_instr.omod || op_instr.f32_to_f16)
|
||||
continue;
|
||||
|
||||
aco_type type = instr_info.alu_opcode_infos[(int)info.opcode].op_types[op_idx];
|
||||
if (!backpropagate_input_modifiers(ctx, op_instr, info.operands[op_idx], type))
|
||||
continue;
|
||||
|
||||
for (const combine_instr_pattern& pattern : patterns) {
|
||||
if (!(pattern.operand_mask & BITFIELD_BIT(op_idx)) ||
|
||||
op_instr.opcode != pattern.src_opcode)
|
||||
continue;
|
||||
|
||||
alu_opt_info new_info = info;
|
||||
|
||||
unsigned rem = info.operands.size() - 1;
|
||||
unsigned op_count = rem + op_instr.operands.size();
|
||||
new_info.operands.resize(op_count);
|
||||
assert(strlen(pattern.swizzle) == op_count);
|
||||
for (unsigned i = 0; i < op_count; i++) {
|
||||
unsigned src_idx = pattern.swizzle[i] - '0';
|
||||
if (src_idx < op_idx)
|
||||
new_info.operands[i] = info.operands[src_idx];
|
||||
else if (src_idx < rem)
|
||||
new_info.operands[i] = info.operands[src_idx + 1];
|
||||
else
|
||||
new_info.operands[i] = op_instr.operands[src_idx - rem];
|
||||
}
|
||||
|
||||
new_info.opcode = pattern.res_opcode;
|
||||
|
||||
if (op_instr.defs[0].isPrecise())
|
||||
new_info.defs[0].setPrecise(true);
|
||||
|
||||
if (pattern.callback && !pattern.callback(ctx, new_info))
|
||||
continue;
|
||||
|
||||
if (alu_opt_info_is_valid(ctx, new_info)) {
|
||||
info = std::move(new_info);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/* s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
|
||||
bool
|
||||
combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue