nir/lower_bit_size: Avoid round-trip conversion when possible

When we detect that the source is a conversion generated by the pass,
try to get the real source instead of doing a round-trip conversion.

Make sure that the nir_alu_type and the bit_size is the same between what
we need and what's before the detected conversion.

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35744>
This commit is contained in:
Romaric Jodin 2025-06-25 16:09:51 +02:00 committed by Marge Bot
parent f34ddff0bd
commit b4977a1605

View file

@ -48,6 +48,39 @@ convert_to_bit_size(nir_builder *bld, nir_def *src,
return nir_convert_to_bit_size(bld, src, type, bit_size);
}
static nir_def *
before_conversion(nir_builder *bld, nir_alu_type type, unsigned bit_size, nir_def *def, nir_op op)
{
/* Filtering in opcode of instruction where the LSB of the output are not affected by the MSB of the inputs */
switch (op) {
case nir_op_iadd:
case nir_op_iadd3:
case nir_op_iand:
case nir_op_imad:
case nir_op_imul:
case nir_op_ior:
case nir_op_ishl:
case nir_op_isub:
case nir_op_ixor:
case nir_op_mov:
break;
default:
return NULL;
}
if (def->parent_instr->type != nir_instr_type_alu) {
return NULL;
}
nir_alu_instr *alu_instr = nir_instr_as_alu(def->parent_instr);
if (alu_instr->op != nir_type_conversion_op((nir_alu_type)(type | bit_size),
(nir_alu_type)(type | def->bit_size),
nir_rounding_mode_undef) ||
alu_instr->src[0].src.ssa->bit_size != bit_size) {
return NULL;
}
/* Handle potential swizzling with 'nir_ssa_for_alu_src' which will be adding a move if needed */
return nir_ssa_for_alu_src(bld, alu_instr, 0);
}
static void
lower_alu_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size)
{
@ -62,8 +95,14 @@ lower_alu_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size)
nir_def *src = nir_ssa_for_alu_src(bld, alu, i);
nir_alu_type type = nir_op_infos[op].input_types[i];
if (nir_alu_type_get_type_size(type) == 0)
src = convert_to_bit_size(bld, src, type, bit_size);
if (nir_alu_type_get_type_size(type) == 0) {
nir_def *src_before_conversion = before_conversion(bld, type, bit_size, src, op);
if (src_before_conversion) {
src = src_before_conversion;
} else {
src = convert_to_bit_size(bld, src, type, bit_size);
}
}
if (i == 1 && (op == nir_op_ishl || op == nir_op_ishr || op == nir_op_ushr ||
op == nir_op_bitz || op == nir_op_bitz8 || op == nir_op_bitz16 ||