nir/opt_vectorize_io: optionally vectorize loads with holes

e.g. load X; load W; ==> load XYZW. Verified with a shader test.

This will be used by AMD drivers. See the code comments.

Reviewed-by: Simon Perretta <simon.perretta@imgtec.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36098>
Marek Olšák 2025-07-13 03:49:04 -04:00 committed by Marge Bot
parent b4977a1605
commit 6286c1c66f
5 changed files with 50 additions and 20 deletions
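
For context, a minimal sketch of how a backend might opt into the new behavior for vertex-shader inputs. The call site below is hypothetical (this commit only adds the parameter; every existing caller passes false), but it matches the updated prototype in nir.h:

/* Hypothetical driver call site: allow vectorizing VS input loads even when
 * some components are unused, since the hardware fetches the whole vector
 * anyway and a separate scalar load would be redundant.
 */
if (nir->info.stage == MESA_SHADER_VERTEX)
   NIR_PASS(_, nir, nir_opt_vectorize_io, nir_var_shader_in, true);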


@@ -1665,8 +1665,8 @@ radv_graphics_shaders_link_varyings(struct radv_shader_stage *stages)
if (next != MESA_SHADER_NONE && stages[next].nir && next != MESA_SHADER_FRAGMENT &&
!stages[s].key.optimisations_disabled && !stages[next].key.optimisations_disabled) {
nir_shader *consumer = stages[next].nir;
-NIR_PASS(_, producer, nir_opt_vectorize_io, nir_var_shader_out);
-NIR_PASS(_, consumer, nir_opt_vectorize_io, nir_var_shader_in);
+NIR_PASS(_, producer, nir_opt_vectorize_io, nir_var_shader_out, false);
+NIR_PASS(_, consumer, nir_opt_vectorize_io, nir_var_shader_in, false);
}
/* Gather shader info; at least the I/O info likely changed


@@ -1505,7 +1505,7 @@ gl_nir_lower_optimize_varyings(const struct gl_constants *consts,
*/
NIR_PASS(_, nir, nir_lower_io_to_scalar, get_varying_nir_var_mask(nir),
NULL, NULL);
-NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir));
+NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir), false);
return;
}
@@ -1569,7 +1569,7 @@ gl_nir_lower_optimize_varyings(const struct gl_constants *consts,
nir_shader *nir = shaders[i];
/* Re-vectorize IO. */
-NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir));
+NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir), false);
/* Recompute intrinsic bases, which are totally random after
* optimizations and compaction. Do that for all inputs and outputs,


@@ -6264,7 +6264,8 @@ bool nir_opt_uniform_subgroup(nir_shader *shader,
bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
void *data);
-bool nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes);
+bool nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes,
+bool allow_holes);
bool nir_opt_move_discards_to_top(nir_shader *shader);


@@ -148,6 +148,9 @@ vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
* inserted.
*/
for (unsigned i = start; i < start + count; i++) {
+if (!chan[i])
+continue;
first = !first || chan[i]->instr.index < first->instr.index ? chan[i] : first;
if (step == merge_low_high_16_to_32) {
first = !first || chan[4 + i]->instr.index < first->instr.index ? chan[4 + i] : first;
@@ -205,7 +208,8 @@ vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
}
} else {
for (unsigned i = start; i < start + count; i++) {
-nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start));
+if (chan[i])
+nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start));
}
}
}
@@ -360,9 +364,11 @@ vectorize_store(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
* (the last 4 are the high 16-bit channels)
*/
static bool
-vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
+vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask, bool allow_holes)
{
bool progress = false;
+assert(mask);
+bool is_load = nir_intrinsic_infos[chan[ffs(mask) - 1]->intrinsic].has_dest;
/* First, merge low and high 16-bit halves into 32 bits separately when
* possible. Then vectorize what's left.
@@ -407,8 +413,18 @@ vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
} else if (step == vectorize_high_16_separately) {
scan_mask = mask & BITFIELD_RANGE(4, 4);
mask &= ~scan_mask;
+if (is_load && allow_holes) {
+unsigned num = util_last_bit(scan_mask);
+scan_mask = BITFIELD_RANGE(4, num - 4);
+}
} else {
scan_mask = mask;
+if (is_load && allow_holes) {
+unsigned num = util_last_bit(scan_mask);
+scan_mask = BITFIELD_MASK(num);
+}
}
while (scan_mask) {
@@ -419,8 +435,6 @@ vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
if (count == 1 && step != merge_low_high_16_to_32)
continue; /* There is nothing to vectorize. */
-bool is_load = nir_intrinsic_infos[chan[start]->intrinsic].has_dest;
if (is_load)
vectorize_load(chan, start, count, step);
else
@@ -434,7 +448,7 @@ vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
}
static bool
-vectorize_batch(struct util_dynarray *io_instructions)
+vectorize_batch(struct util_dynarray *io_instructions, bool allow_holes)
{
unsigned num_instr = util_dynarray_num_elements(io_instructions, void *);
@@ -473,7 +487,7 @@ vectorize_batch(struct util_dynarray *io_instructions)
if (prev && compare_is_not_vectorizable(prev, *intr)) {
/* We need at least 2 instructions to have something to do. */
if (util_bitcount(chan_mask) > 1)
-progress |= vectorize_slot(chan, chan_mask);
+progress |= vectorize_slot(chan, chan_mask, allow_holes);
prev = NULL;
memset(chan, 0, sizeof(chan));
@@ -497,15 +511,28 @@ vectorize_batch(struct util_dynarray *io_instructions)
/* Vectorize the last group. */
if (prev && util_bitcount(chan_mask) > 1)
-progress |= vectorize_slot(chan, chan_mask);
+progress |= vectorize_slot(chan, chan_mask, allow_holes);
/* Clear the array. The next block will reuse it. */
util_dynarray_clear(io_instructions);
return progress;
}
+/* Vectorize lowered IO (load_input/store_output/...).
+ *
+ * modes specifies whether to vectorize inputs and/or outputs.
+ *
+ * allow_holes enables vectorization of loads with holes, e.g.:
+ * load X; load W; ==> load XYZW;
+ *
+ * This is useful for VS input loads where it might not be possible to skip
+ * loading unused components, e.g. with AMD where loading W also loads XYZ,
+ * so if we also load X separately again, it's wasteful. It's better to get
+ * X from the vector that loads (XYZ)W.
+ */
bool
-nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
+nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes,
+bool allow_holes)
{
assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));
@@ -520,8 +547,10 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
* but that is only done when outputs are ignored, so vectorize them
* separately.
*/
-bool progress_in = nir_opt_vectorize_io(shader, nir_var_shader_in);
-bool progress_out = nir_opt_vectorize_io(shader, nir_var_shader_out);
+bool progress_in = nir_opt_vectorize_io(shader, nir_var_shader_in,
+allow_holes);
+bool progress_out = nir_opt_vectorize_io(shader, nir_var_shader_out,
+allow_holes);
return progress_in || progress_out;
}
@@ -584,7 +613,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
*/
if (BITSET_TEST(is_load ? has_output_stores : has_output_loads,
index)) {
-progress |= vectorize_batch(&io_instructions);
+progress |= vectorize_batch(&io_instructions, allow_holes);
BITSET_ZERO(has_output_loads);
BITSET_ZERO(has_output_stores);
}
@@ -595,7 +624,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
/* Don't vectorize across TCS barriers. */
if (modes & nir_var_shader_out &&
nir_intrinsic_memory_modes(intr) & nir_var_shader_out) {
-progress |= vectorize_batch(&io_instructions);
+progress |= vectorize_batch(&io_instructions, allow_holes);
BITSET_ZERO(has_output_loads);
BITSET_ZERO(has_output_stores);
}
@@ -603,7 +632,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
case nir_intrinsic_emit_vertex:
/* Don't vectorize across GS emits. */
-progress |= vectorize_batch(&io_instructions);
+progress |= vectorize_batch(&io_instructions, allow_holes);
BITSET_ZERO(has_output_loads);
BITSET_ZERO(has_output_stores);
continue;
@@ -622,7 +651,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
BITSET_SET(is_load ? has_output_loads : has_output_stores, index);
}
-progress |= vectorize_batch(&io_instructions);
+progress |= vectorize_batch(&io_instructions, allow_holes);
}
nir_progress(progress, impl,


@@ -287,7 +287,7 @@ void pco_lower_nir(pco_ctx *ctx, nir_shader *nir, pco_data *data)
if (nir->info.stage != MESA_SHADER_FRAGMENT)
vec_modes |= nir_var_shader_out;
-NIR_PASS(_, nir, nir_opt_vectorize_io, vec_modes);
+NIR_PASS(_, nir, nir_opt_vectorize_io, vec_modes, false);
/* Special case for frag coords:
* - x,y come from (non-consecutive) special regs - always scalar.
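
To see what allow_holes changes, here is a small standalone sketch of the mask widening done by the new path in vectorize_slot. BITFIELD_MASK and util_last_bit are Mesa utility helpers; trivial stand-ins are defined below (an assumption made for self-containment) so the snippet compiles on its own:

#include <stdio.h>

/* Simplified stand-ins for Mesa's util helpers used by the patch. */
#define BITFIELD_MASK(n) ((n) >= 32 ? 0xffffffffu : ((1u << (n)) - 1u))

static unsigned util_last_bit(unsigned v)
{
   unsigned n = 0;
   while (v) { n++; v >>= 1; }
   return n;
}

int main(void)
{
   /* A load of X and a load of W in the same slot: channel mask 0b1001. */
   unsigned scan_mask = 0x9;

   /* With allow_holes, the scan range is widened up to the last used
    * channel, so X and W can be merged into one XYZW load even though
    * Y and Z are never loaded on their own.
    */
   unsigned num = util_last_bit(scan_mask);   /* 4 */
   unsigned widened = BITFIELD_MASK(num);     /* 0xf = XYZW */

   printf("scan_mask=0x%x widened=0x%x\n", scan_mask, widened);
   return 0;
}

In the pass itself, the widened range only decides how many channels the combined load covers; channels that had no original load stay NULL in chan[] and are simply skipped when rewriting uses, which is what the new "if (!chan[i]) continue;" and "if (chan[i])" guards in vectorize_load handle.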