nir/load_store_vectorize: optimize accesses with u2u64(ishl.nuw(iadd))

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37163>
Author: Rhys Perry, 2025-09-02 10:20:09 +01:00 (committed by Marge Bot)
Parent: 4bc4322150
Commit: cfba417316


@@ -132,10 +132,18 @@ get_info(nir_intrinsic_op op)
return NULL;
}
/* Represents "s * mul". */
/* Represents "s * mul" or "u2u64(s + add32) * mul" (if "s" is 32-bit and the
* final offset/address is 64-bit).
*
* Given two terms s1 and s2 with the same "s":
* "u2u64(s + s1.add32) - u2u64(s + s2.add32) == (u2u64(s1.add32) - u2u64(s2.add32))"
* This is because of the checks in parse_offset(), which ensure that each
* addition overflows only if the other one does too.
*/
struct offset_term {
nir_scalar s;
uint64_t mul;
uint64_t add32;
};
/*
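
To make the guarantee in the offset_term comment above concrete, here is a small standalone C sketch (illustrative only, with made-up values and no NIR types) showing that the 64-bit difference of two u2u64(s + addN) terms collapses to the difference of the constants as long as both 32-bit additions wrap (or neither does):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
   /* Shared 32-bit source "s" and two folded constants; both additions wrap. */
   uint32_t s = 0xffffff00u;
   uint32_t add_a = 0x180u, add_b = 0x200u;

   uint64_t term_a = (uint64_t)(uint32_t)(s + add_a); /* u2u64(s + add_a) == 0x80  */
   uint64_t term_b = (uint64_t)(uint32_t)(s + add_b); /* u2u64(s + add_b) == 0x100 */

   /* Prints -128, i.e. add_a - add_b: the distance between the two terms is
    * recoverable from the constants alone. Had only one of the additions
    * wrapped, the difference would be off by 1ull << 32. */
   printf("%lld\n", (long long)(term_a - term_b));
   return 0;
}
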
@@ -163,10 +171,16 @@ struct entry {
unsigned index;
struct entry_key *key;
/* The constant offset is sign-extended to 64 bits. */
union {
uint64_t offset; /* sign-extended */
uint64_t offset;
int64_t offset_signed;
};
/* Sum of every offset_term::add32 multiplied by its offset_term::mul. We
* don't need to keep the individual values.
*/
uint64_t total_add32;
uint32_t align_mul;
uint32_t align_offset;
@@ -258,15 +272,25 @@ delete_entry_dynarray(struct hash_entry *entry)
ralloc_free(arr);
}
static int64_t
get_offset_diff(struct entry *a, struct entry *b)
{
assert(entry_key_equals(a->key, b->key));
int64_t diff = b->offset_signed - a->offset_signed;
diff += b->total_add32 - a->total_add32;
return diff;
}
static int
sort_entries(const void *a_, const void *b_)
{
struct entry *a = *(struct entry *const *)a_;
struct entry *b = *(struct entry *const *)b_;
if (a->offset_signed > b->offset_signed)
int64_t diff = get_offset_diff(b, a);
if (diff > 0)
return 1;
else if (a->offset_signed < b->offset_signed)
else if (diff < 0)
return -1;
if (a->index > b->index)
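
A minimal mock of the arithmetic in get_offset_diff() above (a toy struct and made-up values, not the pass's real entry type), showing how two entries that differ only in their accumulated add32 constants still get a well-defined distance:

#include <stdint.h>
#include <stdio.h>

struct toy_entry {
   int64_t offset_signed; /* sign-extended constant offset */
   uint64_t total_add32;  /* sum of add32 * mul over all offset terms */
};

static int64_t toy_offset_diff(const struct toy_entry *a, const struct toy_entry *b)
{
   int64_t diff = b->offset_signed - a->offset_signed;
   diff += (int64_t)(b->total_add32 - a->total_add32);
   return diff;
}

int main(void)
{
   /* Same constant offset; the second access sits 8 bytes further only
    * because of its u2u64(s + add32) term. */
   struct toy_entry low = { .offset_signed = 16, .total_add32 = 0 };
   struct toy_entry high = { .offset_signed = 16, .total_add32 = 8 };
   printf("%lld\n", (long long)toy_offset_diff(&low, &high)); /* prints 8 */
   return 0;
}
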
@@ -353,27 +377,73 @@ parse_offset(nir_scalar base, uint64_t *offset)
uint64_t add = 0;
bool progress = false;
bool require_nuw = false;
uint64_t uub = u_uintN_max(base.def->bit_size);
do {
uint64_t mul2 = 1, add2 = 0;
progress = false;
progress = parse_alu(&base, nir_op_imul, &mul2, require_nuw);
mul *= mul2;
if (parse_alu(&base, nir_op_imul, &mul2, require_nuw)) {
progress = true;
uub = mul2 ? uub / mul2 : 0;
mul *= mul2;
}
mul2 = 0;
progress |= parse_alu(&base, nir_op_ishl, &mul2, require_nuw);
mul <<= mul2;
if (parse_alu(&base, nir_op_ishl, &mul2, require_nuw)) {
progress = true;
uub >>= mul2 & (base.def->bit_size - 1);
mul <<= mul2 & (base.def->bit_size - 1);
}
progress |= parse_alu(&base, nir_op_iadd, &add2, require_nuw);
add += add2 * mul;
if (parse_alu(&base, nir_op_iadd, &add2, require_nuw)) {
progress = true;
uub = u_uintN_max(base.def->bit_size);
add += add2 * mul;
}
if (nir_scalar_is_alu(base) && (nir_scalar_alu_op(base) == nir_op_mov ||
nir_scalar_alu_op(base) == nir_op_u2u64)) {
require_nuw |= nir_scalar_alu_op(base) == nir_op_u2u64;
if (nir_scalar_alu_op(base) == nir_op_u2u64) {
require_nuw = true;
uub = u_uintN_max(base.def->bit_size);
}
base = nir_scalar_chase_alu_src(base, 0);
progress = true;
}
} while (progress);
nir_scalar base32 = base;
uint64_t add32 = 0;
if (require_nuw && parse_alu(&base32, nir_op_iadd, &add32, false)) {
/* base32 + add32 is in [0,uub].
*
* The addition overflows if base32 is in:
* - (uint_max-add32,uint_max] if add32 <= uub
* - (uint_max-add32,uint_max-add32+uub+1] if add32 > uub
*
* The addition does not overflow if base32 is in:
* - [0,uub-add32] if add32 <= uub
*
* If the overflow and no-overflow intervals of "base32 + add32_0" and
* "base32 + add32_1" do not intersect, then:
* - one addition overflows if and only if the other does
* - and "(u2u64(base32) + add32_0) - (u2u64(base32) + add32_1) == (u2u64(add32_0) - u2u64(add32_1))"
*
* Instead of checking whether the intervals of two entries intersect,
* we just ensure they're all a subset of a shared fixed interval:
* - [0,(uint_max+1)/2) for the no-overflow interval
* - [(uint_max+1)/2,uint_max] for the overflow interval
*/
uint32_t uint_max = u_uintN_max(base32.def->bit_size);
uint32_t ovfl_interval_start = uint_max - add32;
uint32_t noovfl_interval_end = add32 <= uub ? uub - add32 : 0;
uint32_t mid = ((uint64_t)uint_max + 1) / 2u;
if (ovfl_interval_start >= (mid - 1) && noovfl_interval_end < mid) {
base = base32;
} else {
add32 = 0;
}
}
if (base.def->parent_instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *intrin = nir_def_as_intrinsic(base.def);
if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
@@ -383,6 +453,7 @@ parse_offset(nir_scalar base, uint64_t *offset)
*offset = add;
term.s = base;
term.mul = mul;
term.add32 = add32;
return term;
}
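
As a worked instance of the interval check in parse_offset() above, here is a standalone sketch with made-up numbers; it assumes the tracked upper bound uub has been reduced to UINT32_MAX >> 2 by a 32-bit ishl.nuw by 2 (the u2u64(ishl.nuw(iadd)) shape from the commit title) and that the inner iadd folds in add32 = 4:

#include <stdint.h>
#include <stdio.h>

int main(void)
{
   uint32_t uint_max = UINT32_MAX;
   uint64_t uub = UINT32_MAX >> 2; /* assumed bound on "base32 + add32" */
   uint32_t add32 = 4;

   uint32_t ovfl_interval_start = uint_max - add32;               /* 0xfffffffb */
   uint32_t noovfl_interval_end = add32 <= uub ? uub - add32 : 0; /* 0x3ffffffb */
   uint32_t mid = ((uint64_t)uint_max + 1) / 2u;                  /* 0x80000000 */

   /* Both intervals stay on their own side of "mid", so add32 can be kept in
    * the offset_term and later subtracted exactly by get_offset_diff().
    * Prints 1. */
   printf("%d\n", ovfl_interval_start >= (mid - 1) && noovfl_interval_end < mid);
   return 0;
}
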
@@ -395,22 +466,28 @@ type_scalar_size_bytes(const struct glsl_type *type)
}
static bool
cmp_scalar(nir_scalar a, nir_scalar b)
cmp_term(struct offset_term a, struct offset_term b)
{
return a.def == b.def ? a.comp > b.comp : a.def->index > b.def->index;
if (a.s.def != b.s.def)
return a.s.def->index > b.s.def->index;
if (a.s.comp != b.s.comp)
return a.s.comp > b.s.comp;
return a.add32 > b.add32;
}
static unsigned
add_to_entry_key(struct offset_term *terms, unsigned count, struct offset_term term)
{
for (unsigned i = 0; i <= count; i++) {
if (i == count || cmp_scalar(term.s, terms[i].s)) {
if (i == count || cmp_term(term, terms[i])) {
/* insert before i */
memmove(terms + i + 1, terms + i,
(count - i) * sizeof(struct offset_term));
terms[i] = term;
return 1;
} else if (nir_scalar_equal(term.s, terms[i].s)) {
} else if (nir_scalar_equal(term.s, terms[i].s) && !terms[i].add32 && !term.add32) {
/* merge with offset_def at i */
terms[i].mul += term.mul;
return 0;
@@ -432,7 +509,13 @@ fill_in_offset_defs(struct vectorize_ctx *ctx, struct entry *entry,
for (unsigned i = 0; i < count; i++) {
key->offset_defs[i] = terms[i].s;
key->offset_defs_mul[i] = terms[i].mul;
key->offset_def_num_lsbz[i] = nir_def_num_lsb_zero(ctx->numlsb_ht, terms[i].s);
unsigned lsb_zero = nir_def_num_lsb_zero(ctx->numlsb_ht, terms[i].s);
if (terms[i].add32)
lsb_zero = MIN2(lsb_zero, ffsll(terms[i].add32) - 1);
key->offset_def_num_lsbz[i] = lsb_zero;
entry->total_add32 += terms[i].add32 * terms[i].mul;
}
}
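
The clamp applied to the known low zero bits above can be illustrated with a small standalone sketch; trailing_zero_bits() is a hypothetical stand-in for ffsll(x) - 1 on a non-zero x, and the values are made up:

#include <stdint.h>
#include <stdio.h>

static unsigned trailing_zero_bits(uint64_t v)
{
   unsigned n = 0;
   for (; v && !(v & 1); v >>= 1)
      n++;
   return n;
}

int main(void)
{
   /* Suppose the scalar "s" is known to be a multiple of 16 (4 low zero bits)
    * but the folded constant is add32 = 4: s + add32 is then only guaranteed
    * to be a multiple of 4, so the alignment knowledge is clamped to 2 bits. */
   unsigned lsb_zero = 4;
   uint64_t add32 = 4;
   if (add32 && trailing_zero_bits(add32) < lsb_zero)
      lsb_zero = trailing_zero_bits(add32);
   printf("%u\n", lsb_zero); /* prints 2 */
   return 0;
}
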
@@ -733,7 +816,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
return false;
}
unsigned high_offset = high->offset_signed - low->offset_signed;
unsigned high_offset = get_offset_diff(low, high);
/* This can cause issues when combining store data. */
if (high_offset % (new_bit_size / 8) != 0)
@@ -749,7 +832,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
unsigned low_size = low->intrin->num_components * get_bit_size(low) / 8;
/* The hole size can be less than 0 if low and high instructions overlap. */
int64_t hole_size = high->offset_signed - (low->offset_signed + low_size);
int64_t hole_size = (int64_t)high_offset - low_size;
if (!ctx->options->callback(low->align_mul,
low->align_offset,
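
For example (made-up sizes): if "low" loads 16 bytes starting at its offset and "high" starts 8 bytes after it, then high_offset is 8 and hole_size is 8 - 16 = -8, i.e. the two loads overlap and there is no hole for the callback to account for.
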
@@ -936,6 +1019,7 @@ vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
first->key = low->key;
first->offset = low->offset;
first->total_add32 = low->total_add32;
first->align_mul = low->align_mul;
first->align_offset = low->align_offset;
@@ -1031,6 +1115,7 @@ vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
second->key = low->key;
second->offset = low->offset;
second->total_add32 = low->total_add32;
second->align_mul = low->align_mul;
second->align_offset = low->align_offset;
@@ -1103,7 +1188,7 @@ may_alias_internal(struct entry *a, struct entry *b, uint32_t a_offset, uint32_t
if (!entry_key_equals(a->key, b->key))
return true;
int64_t diff = (b->offset_signed + b_offset) - (a->offset_signed + a_offset);
int64_t diff = get_offset_diff(a, b) + b_offset - a_offset;
struct entry *first = diff < 0 ? b : a;
unsigned size = get_bit_size(first) / 8u * first->num_components;
@@ -1336,7 +1421,7 @@ try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
if (!can_vectorize(ctx, first, second))
return false;
uint64_t diff = high->offset_signed - low->offset_signed;
uint64_t diff = get_offset_diff(low, high);
if (check_for_robustness(ctx, low, diff))
return false;
@@ -1410,7 +1495,7 @@ try_vectorize_shared2(struct vectorize_ctx *ctx,
if (high->align_mul % low_size || high->align_offset % low_size)
return false;
uint64_t diff = high->offset_signed - low->offset_signed;
uint64_t diff = get_offset_diff(low, high);
bool st64 = diff % (64 * low_size) == 0;
unsigned stride = st64 ? 64 * low_size : low_size;
if (diff % stride || diff > 255 * stride)
@@ -1488,7 +1573,7 @@ vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
struct entry *first = low->index < high->index ? low : high;
struct entry *second = low->index < high->index ? high : low;
uint64_t diff = high->offset_signed - low->offset_signed;
uint64_t diff = get_offset_diff(low, high);
/* Allow overfetching by 28 bytes, which can be rejected by the
* callback if needed. Driver callbacks will likely want to
* restrict this to a smaller value, say 4 bytes (or none).