aco/optimizer: add separate fp16 abs/neg/fcanonicalize labels

In the future, we won't be able to use the register class to detect fp16 vs fp32
because SALU uses s1 for both.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37867>
This commit is contained in:
Georg Lehmann 2025-10-21 15:45:21 +02:00 committed by Marge Bot
parent 9e9d9c0373
commit c61ee32034

View file

@ -55,9 +55,17 @@ enum Label {
label_scc_needed = 1ull << 6,
label_extract = 1ull << 7,
label_abs = 1ull << 16,
label_neg = 1ull << 17,
label_fcanonicalize = 1ull << 18,
/* These have one label for fp16 and one for fp32/64.
* 32bit vs 64bit type mismatches are impossible because
* of the different register class sizes.
*/
label_abs_fp32_64 = 1ull << 16,
label_neg_fp32_64 = 1ull << 17,
label_fcanonicalize_fp32_64 = 1ull << 18,
label_abs_fp16 = 1ull << 19,
label_neg_fp16 = 1ull << 20,
label_fcanonicalize_fp16 = 1ull << 21,
label_canonicalized = 1ull << 22,
/* label_{omod2,omod4,omod5,clamp} are used for both 16 and
@ -77,9 +85,12 @@ enum Label {
static constexpr uint64_t instr_mod_labels =
label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_b2f |
label_uniform_bool | label_scc_invert | label_b2i |
label_fcanonicalize;
static constexpr uint64_t input_mod_labels =
label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64;
static constexpr uint64_t temp_labels = label_temp | label_uniform_bool | label_scc_invert |
label_b2f | label_b2i | input_mod_labels |
label_fcanonicalize_fp32_64 | label_fcanonicalize_fp16;
static constexpr uint64_t val_labels = label_constant | label_mad;
@ -126,25 +137,39 @@ struct ssa_info {
bool is_constant() { return label & label_constant; }
void set_abs(Temp abs_temp)
void set_abs(Temp abs_temp, unsigned bit_size)
{
add_label(label_abs);
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
add_label(bit_size == 16 ? label_abs_fp16 : label_abs_fp32_64);
temp = abs_temp;
}
bool is_abs() { return label & label_abs; }
void set_neg(Temp neg_temp)
bool is_abs(unsigned bit_size)
{
add_label(label_neg);
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
return bit_size == 16 ? label & label_abs_fp16 : label & label_abs_fp32_64;
}
void set_neg(Temp neg_temp, unsigned bit_size)
{
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
add_label(bit_size == 16 ? label_neg_fp16 : label_neg_fp32_64);
temp = neg_temp;
}
bool is_neg() { return label & label_neg; }
void set_neg_abs(Temp neg_abs_temp)
bool is_neg(unsigned bit_size)
{
add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
return bit_size == 16 ? label & label_neg_fp16 : label & label_neg_fp32_64;
}
void set_neg_abs(Temp neg_abs_temp, unsigned bit_size)
{
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
if (bit_size == 16)
add_label((Label)((uint32_t)label_abs_fp16 | (uint32_t)label_neg_fp16));
else
add_label((Label)((uint32_t)label_abs_fp32_64 | (uint32_t)label_neg_fp32_64));
temp = neg_abs_temp;
}
@ -254,13 +279,19 @@ struct ssa_info {
bool is_b2i() { return label & label_b2i; }
void set_fcanonicalize(Temp tmp)
void set_fcanonicalize(Temp tmp, unsigned bit_size)
{
add_label(label_fcanonicalize);
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
add_label(bit_size == 16 ? label_fcanonicalize_fp16 : label_fcanonicalize_fp32_64);
temp = tmp;
}
bool is_fcanonicalize() { return label & label_fcanonicalize; }
bool is_fcanonicalize(unsigned bit_size)
{
assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
return bit_size == 16 ? label & label_fcanonicalize_fp16
: label & label_fcanonicalize_fp32_64;
}
void set_canonicalized() { add_label(label_canonicalized); }
@ -2002,19 +2033,34 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type)
return true;
}
// TODO use parent dst type
if (info.is_fcanonicalize() || info.is_abs() || info.is_neg()) {
if (ctx.info[info.temp.id()].is_canonicalized() ||
(tmp.bytes() == 4 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == fp_denorm_keep)
type.base_type = aco_base_type_uint;
else
type.base_type = aco_base_type_float;
} else {
type.base_type = aco_base_type_uint;
for (unsigned bit_size = tmp.size() == 2 ? 64 : 16; bit_size <= tmp.bytes() * 8; bit_size *= 2) {
if (info.is_fcanonicalize(bit_size) || info.is_abs(bit_size) || info.is_neg(bit_size)) {
type.num_components = 1;
type.bit_size = bit_size;
if (ctx.info[info.temp.id()].is_canonicalized() ||
(bit_size == 32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == fp_denorm_keep)
type.base_type = aco_base_type_uint;
else
type.base_type = aco_base_type_float;
op_info.op = Operand(info.temp);
if (info.is_abs(bit_size))
op_info.abs[0] = true;
if (info.is_neg(bit_size))
op_info.neg[0] = true;
return true;
}
}
type.base_type = aco_base_type_uint;
type.num_components = 1;
type.bit_size = tmp.bytes() * 8;
if (info.is_temp()) {
op_info.op = Operand(info.temp);
return true;
}
if (info.is_extract()) {
op_info.extract[0] = parse_extract(info.parent_instr);
op_info.op = info.parent_instr->operands[0];
@ -2052,14 +2098,6 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type)
return true;
}
if (info.is_temp() || info.is_fcanonicalize() || info.is_abs() || info.is_neg()) {
op_info.op = Operand(info.temp);
if (info.is_abs())
op_info.abs[0] = true;
if (info.is_neg())
op_info.neg[0] = true;
return true;
}
return false;
}
@ -2723,6 +2761,7 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
/* TODO: try to move the negate/abs modifier to the consumer instead */
bool uses_mods = instr->usesModifiers();
bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
unsigned bit_size = fp16 ? 16 : 32;
unsigned denorm_mode = fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32;
for (unsigned i = 0; i < 2; i++) {
@ -2747,16 +2786,16 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
}
if (abs && neg && other.type() == RegType::vgpr)
ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
ctx.info[instr->definitions[0].tempId()].set_neg_abs(other, bit_size);
else if (abs && !neg && other.type() == RegType::vgpr)
ctx.info[instr->definitions[0].tempId()].set_abs(other);
ctx.info[instr->definitions[0].tempId()].set_abs(other, bit_size);
else if (!abs && neg && other.type() == RegType::vgpr)
ctx.info[instr->definitions[0].tempId()].set_neg(other);
ctx.info[instr->definitions[0].tempId()].set_neg(other, bit_size);
else if (!abs && !neg) {
if (denorm_mode == fp_denorm_keep || ctx.info[other.id()].is_canonicalized())
ctx.info[instr->definitions[0].tempId()].set_temp(other);
else
ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other, bit_size);
}
} else if (uses_mods || (instr->definitions[0].isSZPreserve() &&
instr->opcode != aco_opcode::v_mul_legacy_f32)) {
@ -4444,7 +4483,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
* floats. */
/* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
if ((ctx.info[instr->definitions[0].tempId()].label & input_mod_labels) &&
ctx.uses[ctx.info[instr->definitions[0].tempId()].temp.id()] == 1) {
Temp val = ctx.info[instr->definitions[0].tempId()].temp;
Instruction* mul_instr = ctx.info[val.id()].parent_instr;
@ -4467,8 +4506,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
/* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
ctx.uses[mul_instr->definitions[0].tempId()]--;
Definition def = instr->definitions[0];
bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg(def.bytes() * 8);
bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs(def.bytes() * 8);
uint32_t pass_flags = instr->pass_flags;
Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
instr.reset(create_instruction(mul_instr->opcode, format, mul_instr->operands.size(), 1));