ac/nir/ngg: Refactor MS primitive indices for scalarized IO.

Previously, it would hit an assertion when used with scalarized
IO, because scalarization splits the primitive indices
store into smaller, per-component stores.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28704>
This commit is contained in:
Timur Kristóf 2024-04-11 00:45:05 +02:00
parent 76c90f929f
commit 8e24d3426d

View file

@ -3654,22 +3654,38 @@ ac_ngg_get_scratch_lds_size(gl_shader_stage stage,
static void
ms_store_prim_indices(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      lower_ngg_ms_state *s)
{
   /* EXT_mesh_shader primitive indices: array of vectors.
    * They don't count as per-primitive outputs, but the array is indexed
    * by the primitive index, so they are practically per-primitive.
    *
    * With scalarized IO, the store may cover only some components of the
    * vector, indicated by nir_intrinsic_component; honor that offset here
    * instead of assuming a full-vector store from component 0.
    */
   assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
   assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);

   const unsigned component_offset = nir_intrinsic_component(intrin);
   nir_def *store_val = intrin->src[0].ssa;
   assert(store_val->num_components <= 3);

   /* Drop extra components beyond the primitive's vertex count. */
   if (store_val->num_components > s->vertices_per_prim)
      store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);

   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
      /* Indices kept in variables: store each written component,
       * starting at the intrinsic's component offset.
       */
      for (unsigned c = 0; c < store_val->num_components; ++c) {
         const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
         nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
      }

      return;
   }

   nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
   nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);

   /* The max vertex count is 256, so these indices always fit 8 bits.
    * To reduce LDS use, store these as a flat array of 8-bit values.
    */
   nir_store_shared(b, nir_u2u8(b, store_val), offset,
                    .base = s->layout.lds.indices_addr + component_offset);
}
static void
@ -3822,21 +3838,7 @@ ms_store_arrayed_output_intrin(nir_builder *b,
unsigned location = nir_intrinsic_io_semantics(intrin).location;
if (location == VARYING_SLOT_PRIMITIVE_INDICES) {
/* EXT_mesh_shader primitive indices: array of vectors.
* They don't count as per-primitive outputs, but the array is indexed
* by the primitive index, so they are practically per-primitive.
*
* The max vertex count is 256, so these indices always fit 8 bits.
* To reduce LDS use, store these as a flat array of 8-bit values.
*/
assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
assert(nir_intrinsic_component(intrin) == 0);
nir_def *store_val = intrin->src[0].ssa;
nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
ms_store_prim_indices(b, store_val, offset, s);
ms_store_prim_indices(b, intrin, s);
return;
} else if (location == VARYING_SLOT_CULL_PRIMITIVE) {
/* EXT_mesh_shader cull primitive: per-primitive bool.