nir/load_store_vectorize: also parse offsets through u2u64 if additions don't wrap around

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37163>
Author: Daniel Schürmann
Date: 2025-07-17 11:33:34 +02:00 (committed by Marge Bot)
parent 084add9959
commit acb47d2c78

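Why the additions must not wrap: zero-extension does not distribute over a
wrapping 32-bit add, so u2u64(a + c) may only be rewritten as u2u64(a) + c
when the add carries NIR's no_unsigned_wrap flag. A minimal standalone
illustration in plain C (the values are hypothetical, chosen to force the
wrap):

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
   uint32_t a = 0xfffffff0u; /* base offset */
   uint32_t c = 0x20u;       /* constant the pass wants to split off */

   uint64_t folded = (uint64_t)(uint32_t)(a + c); /* u2u64(a + c): wraps to 0x10 */
   uint64_t split  = (uint64_t)a + (uint64_t)c;   /* u2u64(a) + c:  0x100000010 */

   printf("u2u64(a + c) = 0x%" PRIx64 "\n", folded);
   printf("u2u64(a) + c = 0x%" PRIx64 "\n", split);
   /* The two results differ by 1ull << 32; without no_unsigned_wrap the
    * decomposition performed by the diff below would be unsound. */
   return 0;
}
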
@@ -305,11 +305,14 @@ get_effective_alu_op(nir_scalar scalar)
  * sources is a constant, update "def" to be the non-constant source, fill "c"
  * with the constant and return true. */
 static bool
-parse_alu(nir_scalar *def, nir_op op, uint64_t *c)
+parse_alu(nir_scalar *def, nir_op op, uint64_t *c, bool require_nuw)
 {
    if (!nir_scalar_is_alu(*def) || get_effective_alu_op(*def) != op)
       return false;
 
+   if (require_nuw && !nir_def_as_alu(def->def)->no_unsigned_wrap)
+      return false;
+
    nir_scalar src0 = nir_scalar_chase_alu_src(*def, 0);
    nir_scalar src1 = nir_scalar_chase_alu_src(*def, 1);
    if (op != nir_op_ishl && nir_scalar_is_const(src0)) {
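
For intuition, the contract described in the comment above can be modeled
in a few lines of plain C; everything below (expr, parse) is an
illustrative stand-in for nir_scalar and parse_alu, not the NIR API:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

enum op { LEAF, MUL, SHL, ADD };

struct expr {
   enum op op;
   struct expr *src; /* the non-constant source */
   uint64_t c;       /* the constant source */
   bool nuw;         /* models NIR's no_unsigned_wrap flag */
};

/* If *def is the requested op, step *def to its non-constant source,
 * return the constant through c, and honor require_nuw -- mirroring
 * what parse_alu does on nir_scalar chains. */
static bool
parse(struct expr **def, enum op op, uint64_t *c, bool require_nuw)
{
   if ((*def)->op != op || (require_nuw && !(*def)->nuw))
      return false;
   *c = (*def)->c;
   *def = (*def)->src;
   return true;
}

int main(void)
{
   struct expr a   = { LEAF, NULL, 0, false };
   struct expr mul = { MUL, &a, 4, true };    /* a * 4, known not to wrap */
   struct expr add = { ADD, &mul, 8, false }; /* (a * 4) + 8, may wrap */

   struct expr *def = &add;
   uint64_t c;
   printf("%d\n", parse(&def, ADD, &c, true));  /* 0: the add may wrap */
   printf("%d\n", parse(&def, ADD, &c, false)); /* 1: def = a * 4, c = 8 */
   printf("%d\n", parse(&def, MUL, &c, true));  /* 1: def = a, c = 4 */
   return 0;
}
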
@@ -337,20 +340,23 @@ parse_offset(nir_scalar *base, uint64_t *base_mul, uint64_t *offset)
    uint64_t mul = 1;
    uint64_t add = 0;
    bool progress = false;
+   bool require_nuw = false;
    do {
       uint64_t mul2 = 1, add2 = 0;
 
-      progress = parse_alu(base, nir_op_imul, &mul2);
+      progress = parse_alu(base, nir_op_imul, &mul2, require_nuw);
       mul *= mul2;
 
       mul2 = 0;
-      progress |= parse_alu(base, nir_op_ishl, &mul2);
+      progress |= parse_alu(base, nir_op_ishl, &mul2, require_nuw);
       mul <<= mul2;
 
-      progress |= parse_alu(base, nir_op_iadd, &add2);
+      progress |= parse_alu(base, nir_op_iadd, &add2, require_nuw);
       add += add2 * mul;
 
-      if (nir_scalar_is_alu(*base) && nir_scalar_alu_op(*base) == nir_op_mov) {
+      if (nir_scalar_is_alu(*base) && (nir_scalar_alu_op(*base) == nir_op_mov ||
+                                       nir_scalar_alu_op(*base) == nir_op_u2u64)) {
+         require_nuw |= nir_scalar_alu_op(*base) == nir_op_u2u64;
          *base = nir_scalar_chase_alu_src(*base, 0);
          progress = true;
       }
@@ -384,8 +390,6 @@ static unsigned
 add_to_entry_key(nir_scalar *offset_defs, uint64_t *offset_defs_mul,
                  unsigned offset_def_count, nir_scalar def, uint64_t mul)
 {
-   mul = util_mask_sign_extend(mul, def.def->bit_size);
-
    for (unsigned i = 0; i <= offset_def_count; i++) {
       if (i == offset_def_count || cmp_scalar(def, offset_defs[i])) {
          /* insert before i */
@@ -453,9 +457,10 @@ create_entry_key_from_deref(void *mem_ctx,
 
       *offset_base += offset * stride;
       if (base.def) {
+         uint64_t mul = util_mask_sign_extend(base_mul * stride, index->bit_size);
          offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
                                               offset_def_count,
-                                              base, base_mul * stride);
+                                              base, mul);
       }
       break;
    }
@@ -493,6 +498,7 @@ static unsigned
 parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
                             nir_scalar base, uint64_t base_mul, uint64_t *offset)
 {
+   unsigned original_bit_size = base.def->bit_size;
    uint64_t new_mul;
    uint64_t new_offset;
    parse_offset(&base, &new_mul, &new_offset);
@@ -505,7 +511,7 @@ parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
 
    assert(left >= 1);
 
-   if (left >= 2) {
+   if (left >= 2 && base.def->bit_size == original_bit_size) {
       if (nir_scalar_is_alu(base) && nir_scalar_alu_op(base) == nir_op_iadd) {
          nir_scalar src0 = nir_scalar_chase_alu_src(base, 0);
          nir_scalar src1 = nir_scalar_chase_alu_src(base, 1);
@@ -515,6 +521,7 @@ parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
       }
    }
 
+   base_mul = util_mask_sign_extend(base_mul, original_bit_size);
    return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
 }
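
The last three hunks move the multiplier sign-extension out of
add_to_entry_key and into its callers: once parse_offset can chase through
u2u64, base.def->bit_size may no longer match the bit size the offset
arithmetic was performed in (hence the original_bit_size guard above), so
the callers must extend at the original width. A sketch of what the helper
computes (an illustrative reimplementation, not Mesa's util header):

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

/* Treat the low bit_size bits of val as a signed value and sign-extend
 * them to 64 bits, mirroring util_mask_sign_extend. */
static int64_t
mask_sign_extend(uint64_t val, unsigned bit_size)
{
   unsigned shift = 64 - bit_size;
   return (int64_t)(val << shift) >> shift;
}

int main(void)
{
   uint64_t mul = 0xfffffffcu; /* a multiplier of -4 computed in 32 bits */

   /* Extended at the original 32-bit size this is -4; treated as a
    * 64-bit value it stays a large positive constant. Mixing the two
    * would make otherwise-identical entry keys compare unequal. */
   printf("%" PRId64 "\n", mask_sign_extend(mul, 32)); /* -4 */
   printf("%" PRId64 "\n", mask_sign_extend(mul, 64)); /* 4294967292 */
   return 0;
}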