nir/opt_vectorize_io: optionally vectorize loads with holes

e.g. load X; load W; ==> load XYZW. Verified with a shader test.

This will be used by AMD drivers. See the code comments.

Reviewed-by: Simon Perretta <simon.perretta@imgtec.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36098>
This commit is contained in:
Marek Olšák 2025-07-13 03:49:04 -04:00 committed by Marge Bot
parent b4977a1605
commit 6286c1c66f
5 changed files with 50 additions and 20 deletions

View file

@ -1665,8 +1665,8 @@ radv_graphics_shaders_link_varyings(struct radv_shader_stage *stages)
if (next != MESA_SHADER_NONE && stages[next].nir && next != MESA_SHADER_FRAGMENT && if (next != MESA_SHADER_NONE && stages[next].nir && next != MESA_SHADER_FRAGMENT &&
!stages[s].key.optimisations_disabled && !stages[next].key.optimisations_disabled) { !stages[s].key.optimisations_disabled && !stages[next].key.optimisations_disabled) {
nir_shader *consumer = stages[next].nir; nir_shader *consumer = stages[next].nir;
NIR_PASS(_, producer, nir_opt_vectorize_io, nir_var_shader_out); NIR_PASS(_, producer, nir_opt_vectorize_io, nir_var_shader_out, false);
NIR_PASS(_, consumer, nir_opt_vectorize_io, nir_var_shader_in); NIR_PASS(_, consumer, nir_opt_vectorize_io, nir_var_shader_in, false);
} }
/* Gather shader info; at least the I/O info likely changed /* Gather shader info; at least the I/O info likely changed

View file

@ -1505,7 +1505,7 @@ gl_nir_lower_optimize_varyings(const struct gl_constants *consts,
*/ */
NIR_PASS(_, nir, nir_lower_io_to_scalar, get_varying_nir_var_mask(nir), NIR_PASS(_, nir, nir_lower_io_to_scalar, get_varying_nir_var_mask(nir),
NULL, NULL); NULL, NULL);
NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir)); NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir), false);
return; return;
} }
@ -1569,7 +1569,7 @@ gl_nir_lower_optimize_varyings(const struct gl_constants *consts,
nir_shader *nir = shaders[i]; nir_shader *nir = shaders[i];
/* Re-vectorize IO. */ /* Re-vectorize IO. */
NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir)); NIR_PASS(_, nir, nir_opt_vectorize_io, get_varying_nir_var_mask(nir), false);
/* Recompute intrinsic bases, which are totally random after /* Recompute intrinsic bases, which are totally random after
* optimizations and compaction. Do that for all inputs and outputs, * optimizations and compaction. Do that for all inputs and outputs,

View file

@ -6264,7 +6264,8 @@ bool nir_opt_uniform_subgroup(nir_shader *shader,
bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter, bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
void *data); void *data);
bool nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes); bool nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes,
bool allow_holes);
bool nir_opt_move_discards_to_top(nir_shader *shader); bool nir_opt_move_discards_to_top(nir_shader *shader);

View file

@ -148,6 +148,9 @@ vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
* inserted. * inserted.
*/ */
for (unsigned i = start; i < start + count; i++) { for (unsigned i = start; i < start + count; i++) {
if (!chan[i])
continue;
first = !first || chan[i]->instr.index < first->instr.index ? chan[i] : first; first = !first || chan[i]->instr.index < first->instr.index ? chan[i] : first;
if (step == merge_low_high_16_to_32) { if (step == merge_low_high_16_to_32) {
first = !first || chan[4 + i]->instr.index < first->instr.index ? chan[4 + i] : first; first = !first || chan[4 + i]->instr.index < first->instr.index ? chan[4 + i] : first;
@ -205,7 +208,8 @@ vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
} }
} else { } else {
for (unsigned i = start; i < start + count; i++) { for (unsigned i = start; i < start + count; i++) {
nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start)); if (chan[i])
nir_def_replace(&chan[i]->def, nir_channel(&b, def, i - start));
} }
} }
} }
@ -360,9 +364,11 @@ vectorize_store(nir_intrinsic_instr *chan[8], unsigned start, unsigned count,
* (the last 4 are the high 16-bit channels) * (the last 4 are the high 16-bit channels)
*/ */
static bool static bool
vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask) vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask, bool allow_holes)
{ {
bool progress = false; bool progress = false;
assert(mask);
bool is_load = nir_intrinsic_infos[chan[ffs(mask) - 1]->intrinsic].has_dest;
/* First, merge low and high 16-bit halves into 32 bits separately when /* First, merge low and high 16-bit halves into 32 bits separately when
* possible. Then vectorize what's left. * possible. Then vectorize what's left.
@ -407,8 +413,18 @@ vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
} else if (step == vectorize_high_16_separately) { } else if (step == vectorize_high_16_separately) {
scan_mask = mask & BITFIELD_RANGE(4, 4); scan_mask = mask & BITFIELD_RANGE(4, 4);
mask &= ~scan_mask; mask &= ~scan_mask;
if (is_load && allow_holes) {
unsigned num = util_last_bit(scan_mask);
scan_mask = BITFIELD_RANGE(4, num - 4);
}
} else { } else {
scan_mask = mask; scan_mask = mask;
if (is_load && allow_holes) {
unsigned num = util_last_bit(scan_mask);
scan_mask = BITFIELD_MASK(num);
}
} }
while (scan_mask) { while (scan_mask) {
@ -419,8 +435,6 @@ vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
if (count == 1 && step != merge_low_high_16_to_32) if (count == 1 && step != merge_low_high_16_to_32)
continue; /* There is nothing to vectorize. */ continue; /* There is nothing to vectorize. */
bool is_load = nir_intrinsic_infos[chan[start]->intrinsic].has_dest;
if (is_load) if (is_load)
vectorize_load(chan, start, count, step); vectorize_load(chan, start, count, step);
else else
@ -434,7 +448,7 @@ vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
} }
static bool static bool
vectorize_batch(struct util_dynarray *io_instructions) vectorize_batch(struct util_dynarray *io_instructions, bool allow_holes)
{ {
unsigned num_instr = util_dynarray_num_elements(io_instructions, void *); unsigned num_instr = util_dynarray_num_elements(io_instructions, void *);
@ -473,7 +487,7 @@ vectorize_batch(struct util_dynarray *io_instructions)
if (prev && compare_is_not_vectorizable(prev, *intr)) { if (prev && compare_is_not_vectorizable(prev, *intr)) {
/* We need at least 2 instructions to have something to do. */ /* We need at least 2 instructions to have something to do. */
if (util_bitcount(chan_mask) > 1) if (util_bitcount(chan_mask) > 1)
progress |= vectorize_slot(chan, chan_mask); progress |= vectorize_slot(chan, chan_mask, allow_holes);
prev = NULL; prev = NULL;
memset(chan, 0, sizeof(chan)); memset(chan, 0, sizeof(chan));
@ -497,15 +511,28 @@ vectorize_batch(struct util_dynarray *io_instructions)
/* Vectorize the last group. */ /* Vectorize the last group. */
if (prev && util_bitcount(chan_mask) > 1) if (prev && util_bitcount(chan_mask) > 1)
progress |= vectorize_slot(chan, chan_mask); progress |= vectorize_slot(chan, chan_mask, allow_holes);
/* Clear the array. The next block will reuse it. */ /* Clear the array. The next block will reuse it. */
util_dynarray_clear(io_instructions); util_dynarray_clear(io_instructions);
return progress; return progress;
} }
/* Vectorize lowered IO (load_input/store_output/...).
*
* modes specifies whether to vectorize inputs and/or outputs.
*
* allow_holes enables vectorization of loads with holes, e.g.:
* load X; load W; ==> load XYZW;
*
* This is useful for VS input loads where it might not be possible to skip
* loading unused components, e.g. with AMD where loading W also loads XYZ,
* so if we also load X separately again, it's wasteful. It's better to get
* X from the vector that loads (XYZ)W.
*/
bool bool
nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes) nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes,
bool allow_holes)
{ {
assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out))); assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));
@ -520,8 +547,10 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
* but that is only done when outputs are ignored, so vectorize them * but that is only done when outputs are ignored, so vectorize them
* separately. * separately.
*/ */
bool progress_in = nir_opt_vectorize_io(shader, nir_var_shader_in); bool progress_in = nir_opt_vectorize_io(shader, nir_var_shader_in,
bool progress_out = nir_opt_vectorize_io(shader, nir_var_shader_out); allow_holes);
bool progress_out = nir_opt_vectorize_io(shader, nir_var_shader_out,
allow_holes);
return progress_in || progress_out; return progress_in || progress_out;
} }
@ -584,7 +613,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
*/ */
if (BITSET_TEST(is_load ? has_output_stores : has_output_loads, if (BITSET_TEST(is_load ? has_output_stores : has_output_loads,
index)) { index)) {
progress |= vectorize_batch(&io_instructions); progress |= vectorize_batch(&io_instructions, allow_holes);
BITSET_ZERO(has_output_loads); BITSET_ZERO(has_output_loads);
BITSET_ZERO(has_output_stores); BITSET_ZERO(has_output_stores);
} }
@ -595,7 +624,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
/* Don't vectorize across TCS barriers. */ /* Don't vectorize across TCS barriers. */
if (modes & nir_var_shader_out && if (modes & nir_var_shader_out &&
nir_intrinsic_memory_modes(intr) & nir_var_shader_out) { nir_intrinsic_memory_modes(intr) & nir_var_shader_out) {
progress |= vectorize_batch(&io_instructions); progress |= vectorize_batch(&io_instructions, allow_holes);
BITSET_ZERO(has_output_loads); BITSET_ZERO(has_output_loads);
BITSET_ZERO(has_output_stores); BITSET_ZERO(has_output_stores);
} }
@ -603,7 +632,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
case nir_intrinsic_emit_vertex: case nir_intrinsic_emit_vertex:
/* Don't vectorize across GS emits. */ /* Don't vectorize across GS emits. */
progress |= vectorize_batch(&io_instructions); progress |= vectorize_batch(&io_instructions, allow_holes);
BITSET_ZERO(has_output_loads); BITSET_ZERO(has_output_loads);
BITSET_ZERO(has_output_stores); BITSET_ZERO(has_output_stores);
continue; continue;
@ -622,7 +651,7 @@ nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
BITSET_SET(is_load ? has_output_loads : has_output_stores, index); BITSET_SET(is_load ? has_output_loads : has_output_stores, index);
} }
progress |= vectorize_batch(&io_instructions); progress |= vectorize_batch(&io_instructions, allow_holes);
} }
nir_progress(progress, impl, nir_progress(progress, impl,

View file

@ -287,7 +287,7 @@ void pco_lower_nir(pco_ctx *ctx, nir_shader *nir, pco_data *data)
if (nir->info.stage != MESA_SHADER_FRAGMENT) if (nir->info.stage != MESA_SHADER_FRAGMENT)
vec_modes |= nir_var_shader_out; vec_modes |= nir_var_shader_out;
NIR_PASS(_, nir, nir_opt_vectorize_io, vec_modes); NIR_PASS(_, nir, nir_opt_vectorize_io, vec_modes, false);
/* Special case for frag coords: /* Special case for frag coords:
* - x,y come from (non-consecutive) special regs - always scalar. * - x,y come from (non-consecutive) special regs - always scalar.