diff --git a/src/compiler/nir/nir_opt_load_store_vectorize.c b/src/compiler/nir/nir_opt_load_store_vectorize.c
index afac1f9fdbc..6b75446f5d9 100644
--- a/src/compiler/nir/nir_opt_load_store_vectorize.c
+++ b/src/compiler/nir/nir_opt_load_store_vectorize.c
@@ -132,10 +132,18 @@ get_info(nir_intrinsic_op op)
    return NULL;
 }
 
-/* Represents "s * mul". */
+/* Represents "s * mul" or "u2u64(s + add32) * mul" (if "s" is 32-bit and the
+ * final offset/address is 64-bit).
+ *
+ * Given two terms s1 and s2 with the same "s":
+ * "u2u64(s + s1.add32) - u2u64(s + s2.add32) == (u2u64(s1.add32) - u2u64(s2.add32))"
+ * This is because of the checks in parse_offset() which ensure each addition
+ * only overflows if the other does too.
+ */
 struct offset_term {
    nir_scalar s;
    uint64_t mul;
+   uint64_t add32;
 };
 
 /*
@@ -163,10 +171,16 @@ struct entry {
    unsigned index;
 
    struct entry_key *key;
+   /* The constant offset is sign-extended to 64 bits. */
    union {
-      uint64_t offset; /* sign-extended */
+      uint64_t offset;
       int64_t offset_signed;
    };
+   /* Total of each offset_term::add32 multiplied by offset_term::mul. We don't
+    * need to keep each individual one.
+    */
+   uint64_t total_add32;
+
    uint32_t align_mul;
    uint32_t align_offset;
 
@@ -258,15 +272,25 @@ delete_entry_dynarray(struct hash_entry *entry)
    ralloc_free(arr);
 }
 
+static int64_t
+get_offset_diff(struct entry *a, struct entry *b)
+{
+   assert(entry_key_equals(a->key, b->key));
+   int64_t diff = b->offset_signed - a->offset_signed;
+   diff += b->total_add32 - a->total_add32;
+   return diff;
+}
+
 static int
 sort_entries(const void *a_, const void *b_)
 {
    struct entry *a = *(struct entry *const *)a_;
    struct entry *b = *(struct entry *const *)b_;
 
-   if (a->offset_signed > b->offset_signed)
+   int64_t diff = get_offset_diff(b, a);
+   if (diff > 0)
       return 1;
-   else if (a->offset_signed < b->offset_signed)
+   else if (diff < 0)
       return -1;
 
    if (a->index > b->index)
@@ -353,27 +377,73 @@ parse_offset(nir_scalar base, uint64_t *offset)
    uint64_t add = 0;
    bool progress = false;
    bool require_nuw = false;
+   uint64_t uub = u_uintN_max(base.def->bit_size);
    do {
       uint64_t mul2 = 1, add2 = 0;
+      progress = false;
 
-      progress = parse_alu(&base, nir_op_imul, &mul2, require_nuw);
-      mul *= mul2;
+      if (parse_alu(&base, nir_op_imul, &mul2, require_nuw)) {
+         progress = true;
+         uub = mul2 ? uub / mul2 : 0;
+         mul *= mul2;
+      }
 
-      mul2 = 0;
-      progress |= parse_alu(&base, nir_op_ishl, &mul2, require_nuw);
-      mul <<= mul2;
+      if (parse_alu(&base, nir_op_ishl, &mul2, require_nuw)) {
+         progress = true;
+         uub >>= mul2 & (base.def->bit_size - 1);
+         mul <<= mul2 & (base.def->bit_size - 1);
+      }
 
-      progress |= parse_alu(&base, nir_op_iadd, &add2, require_nuw);
-      add += add2 * mul;
+      if (parse_alu(&base, nir_op_iadd, &add2, require_nuw)) {
+         progress = true;
+         uub = u_uintN_max(base.def->bit_size);
+         add += add2 * mul;
+      }
 
       if (nir_scalar_is_alu(base) && (nir_scalar_alu_op(base) == nir_op_mov ||
                                       nir_scalar_alu_op(base) == nir_op_u2u64)) {
-         require_nuw |= nir_scalar_alu_op(base) == nir_op_u2u64;
+         if (nir_scalar_alu_op(base) == nir_op_u2u64) {
+            require_nuw = true;
+            uub = u_uintN_max(base.def->bit_size);
+         }
          base = nir_scalar_chase_alu_src(base, 0);
          progress = true;
       }
    } while (progress);
 
+   nir_scalar base32 = base;
+   uint64_t add32 = 0;
+   if (require_nuw && parse_alu(&base32, nir_op_iadd, &add32, false)) {
+      /* base32 + add32 is in [0,uub].
+       *
+       * The addition overflows if base32 is in:
+       * - (uint_max-add32,uint_max] if add32 <= uub
+       * - (uint_max-add32,uint_max-add32+uub+1] if add32 > uub
+       *
+       * The addition does not overflow if base32 is in:
+       * - [0,uub-add32] if add32 <= uub
+       *
+       * If the overflow interval of "base32 + add32_0" does not intersect the
+       * no-overflow interval of "base32 + add32_1" (and vice versa), then:
+       * - one addition overflows if and only if the other does
+       * - and "(u2u64(base32) + add32_0) - (u2u64(base32) + add32_1) == (u2u64(add32_0) - u2u64(add32_1))"
+       *
+       * Instead of checking whether the intervals of two entries intersect,
+       * we just ensure each is a subset of a shared fixed interval:
+       * - [0,(uint_max+1)/2) for the no-overflow interval
+       * - [(uint_max+1)/2,uint_max] for the overflow interval
+       */
+      uint32_t uint_max = u_uintN_max(base32.def->bit_size);
+      uint32_t ovfl_interval_start = uint_max - add32;
+      uint32_t noovfl_interval_end = add32 <= uub ? uub - add32 : 0;
+      uint32_t mid = ((uint64_t)uint_max + 1) / 2u;
+      if (ovfl_interval_start >= (mid - 1) && noovfl_interval_end < mid) {
+         base = base32;
+      } else {
+         add32 = 0;
+      }
+   }
+
    if (base.def->parent_instr->type == nir_instr_type_intrinsic) {
       nir_intrinsic_instr *intrin = nir_def_as_intrinsic(base.def);
       if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
@@ -383,6 +453,7 @@ parse_offset(nir_scalar base, uint64_t *offset)
    *offset = add;
    term.s = base;
    term.mul = mul;
+   term.add32 = add32;
    return term;
 }
 
@@ -395,22 +466,28 @@ type_scalar_size_bytes(const struct glsl_type *type)
 }
 
 static bool
-cmp_scalar(nir_scalar a, nir_scalar b)
+cmp_term(struct offset_term a, struct offset_term b)
 {
-   return a.def == b.def ? a.comp > b.comp : a.def->index > b.def->index;
+   if (a.s.def != b.s.def)
+      return a.s.def->index > b.s.def->index;
+
+   if (a.s.comp != b.s.comp)
+      return a.s.comp > b.s.comp;
+
+   return a.add32 > b.add32;
 }
 
 static unsigned
 add_to_entry_key(struct offset_term *terms, unsigned count,
                  struct offset_term term)
 {
    for (unsigned i = 0; i <= count; i++) {
-      if (i == count || cmp_scalar(term.s, terms[i].s)) {
+      if (i == count || cmp_term(term, terms[i])) {
          /* insert before i */
          memmove(terms + i + 1, terms + i, (count - i) * sizeof(struct offset_term));
         terms[i] = term;
          return 1;
-      } else if (nir_scalar_equal(term.s, terms[i].s)) {
+      } else if (nir_scalar_equal(term.s, terms[i].s) && !terms[i].add32 && !term.add32) {
          /* merge with offset_def at i */
          terms[i].mul += term.mul;
          return 0;
@@ -432,7 +509,13 @@ fill_in_offset_defs(struct vectorize_ctx *ctx, struct entry *entry,
    for (unsigned i = 0; i < count; i++) {
       key->offset_defs[i] = terms[i].s;
       key->offset_defs_mul[i] = terms[i].mul;
-      key->offset_def_num_lsbz[i] = nir_def_num_lsb_zero(ctx->numlsb_ht, terms[i].s);
+
+      unsigned lsb_zero = nir_def_num_lsb_zero(ctx->numlsb_ht, terms[i].s);
+      if (terms[i].add32)
+         lsb_zero = MIN2(lsb_zero, ffsll(terms[i].add32) - 1);
+      key->offset_def_num_lsbz[i] = lsb_zero;
+
+      entry->total_add32 += terms[i].add32 * terms[i].mul;
    }
 }
 
@@ -733,7 +816,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
          return false;
    }
 
-   unsigned high_offset = high->offset_signed - low->offset_signed;
+   unsigned high_offset = get_offset_diff(low, high);
    /* This can cause issues when combining store data.
     */
    if (high_offset % (new_bit_size / 8) != 0)
@@ -749,7 +832,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
    unsigned low_size = low->intrin->num_components * get_bit_size(low) / 8;
 
    /* The hole size can be less than 0 if low and high instructions overlap. */
-   int64_t hole_size = high->offset_signed - (low->offset_signed + low_size);
+   int64_t hole_size = (int64_t)high_offset - low_size;
 
    if (!ctx->options->callback(low->align_mul,
                                low->align_offset,
@@ -936,6 +1019,7 @@ vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
 
    first->key = low->key;
    first->offset = low->offset;
+   first->total_add32 = low->total_add32;
 
    first->align_mul = low->align_mul;
    first->align_offset = low->align_offset;
@@ -1031,6 +1115,7 @@ vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
 
    second->key = low->key;
    second->offset = low->offset;
+   second->total_add32 = low->total_add32;
 
    second->align_mul = low->align_mul;
    second->align_offset = low->align_offset;
@@ -1103,7 +1188,7 @@ may_alias_internal(struct entry *a, struct entry *b, uint32_t a_offset, uint32_t
    if (!entry_key_equals(a->key, b->key))
       return true;
 
-   int64_t diff = (b->offset_signed + b_offset) - (a->offset_signed + a_offset);
+   int64_t diff = get_offset_diff(a, b) + b_offset - a_offset;
    struct entry *first = diff < 0 ? b : a;
    unsigned size = get_bit_size(first) / 8u * first->num_components;
 
@@ -1336,7 +1421,7 @@ try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
    if (!can_vectorize(ctx, first, second))
      return false;
 
-   uint64_t diff = high->offset_signed - low->offset_signed;
+   uint64_t diff = get_offset_diff(low, high);
    if (check_for_robustness(ctx, low, diff))
      return false;
 
@@ -1410,7 +1495,7 @@ try_vectorize_shared2(struct vectorize_ctx *ctx,
    if (high->align_mul % low_size || high->align_offset % low_size)
      return false;
 
-   uint64_t diff = high->offset_signed - low->offset_signed;
+   uint64_t diff = get_offset_diff(low, high);
    bool st64 = diff % (64 * low_size) == 0;
    unsigned stride = st64 ? 64 * low_size : low_size;
    if (diff % stride || diff > 255 * stride)
@@ -1488,7 +1573,7 @@ vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
       struct entry *first = low->index < high->index ? low : high;
       struct entry *second = low->index < high->index ? high : low;
 
-      uint64_t diff = high->offset_signed - low->offset_signed;
+      uint64_t diff = get_offset_diff(low, high);
       /* Allow overfetching by 28 bytes, which can be rejected by the
        * callback if needed. Driver callbacks will likely want to
        * restrict this to a smaller value, say 4 bytes (or none).
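
As a sanity check of the pairing argument in the parse_offset() comment, here is a small standalone C program (not part of the patch, and using no Mesa API) that brute-forces the claim on 8-bit words: whenever two additions on the same base either both wrap or both don't, zero-extending the wrapped sums preserves their difference exactly, so it equals the difference of the constants. The mixed case, which the shared-half interval check rules out, is the only one skipped.

/* Standalone sketch, not Mesa code: verify the same-regime identity on
 * 8-bit words (zero-extension to 64 bits plays the role of u2u64 here).
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

int
main(void)
{
   unsigned checked = 0, mismatches = 0;

   for (unsigned base = 0; base <= UINT8_MAX; base++) {
      for (unsigned a0 = 0; a0 <= UINT8_MAX; a0++) {
         for (unsigned a1 = 0; a1 <= UINT8_MAX; a1++) {
            bool ovfl0 = base + a0 > UINT8_MAX;
            bool ovfl1 = base + a1 > UINT8_MAX;
            if (ovfl0 != ovfl1)
               continue; /* mixed regime: exactly what the interval check forbids */

            /* Wrapping 8-bit add, then zero-extend. */
            int64_t z0 = (uint8_t)(base + a0);
            int64_t z1 = (uint8_t)(base + a1);
            if (z0 - z1 != (int64_t)a0 - (int64_t)a1)
               mismatches++;
            checked++;
         }
      }
   }

   printf("%u same-regime pairs checked, %u mismatches\n", checked, mismatches);
   return mismatches != 0;
}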
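
To make the new bookkeeping concrete, here is a toy example (hypothetical address expressions, not Mesa API) of how a constant folded out of a 32-bit base ends up split between entry::offset and entry::total_add32, and how get_offset_diff() recovers the byte distance between two entries:

/* Hypothetical example: two loads sharing the 32-bit base scalar "i",
 *   A at u2u64(i + 4) * 16 + 32
 *   B at u2u64(i + 12) * 16
 * parse_offset() gives both the key term {s = i, mul = 16}; the constants
 * split into offset and total_add32 = add32 * mul:
 *   A: offset = 32, total_add32 =  4 * 16 = 64
 *   B: offset =  0, total_add32 = 12 * 16 = 192
 */
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

struct toy_entry {
   int64_t offset_signed;
   uint64_t total_add32;
};

/* Mirrors get_offset_diff(): constant offsets plus the folded add32 totals. */
static int64_t
toy_offset_diff(const struct toy_entry *a, const struct toy_entry *b)
{
   int64_t diff = b->offset_signed - a->offset_signed;
   diff += b->total_add32 - a->total_add32;
   return diff;
}

int
main(void)
{
   struct toy_entry a = { .offset_signed = 32, .total_add32 = 4 * 16 };
   struct toy_entry b = { .offset_signed = 0, .total_add32 = 12 * 16 };

   /* Prints 96: B accesses 96 bytes past A whenever the add32 pairing holds. */
   printf("byte distance: %" PRId64 "\n", toy_offset_diff(&a, &b));
   return 0;
}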
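
The MIN2(lsb_zero, ffsll(add32) - 1) clamp in fill_in_offset_defs() can be illustrated the same way (made-up values; __builtin_ctzll stands in for ffsll() - 1): folding a constant into the term can only weaken the scalar's known alignment, down to the constant's trailing-zero count.

/* Standalone sketch: how many low bits of (base + add32) stay provably zero
 * when base is a known multiple of 2^base_lsb_zero. Values are made up.
 */
#include <stdint.h>
#include <stdio.h>

int
main(void)
{
   unsigned base_lsb_zero = 4; /* base known to be a multiple of 16 */
   uint64_t add32 = 4;         /* constant folded into the term */

   unsigned lsb_zero = base_lsb_zero;
   if (add32) {
      unsigned add_tz = (unsigned)__builtin_ctzll(add32); /* == ffsll(add32) - 1 */
      lsb_zero = lsb_zero < add_tz ? lsb_zero : add_tz;
   }

   /* Prints 2: base + 4 is only guaranteed to be a multiple of 4. */
   printf("guaranteed low zero bits: %u\n", lsb_zero);
   return 0;
}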