diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 4775121233a..dfdcce123d4 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -3654,22 +3654,38 @@ ac_ngg_get_scratch_lds_size(gl_shader_stage stage,
 
 static void
 ms_store_prim_indices(nir_builder *b,
-                      nir_def *val,
-                      nir_def *offset_src,
+                      nir_intrinsic_instr *intrin,
                       lower_ngg_ms_state *s)
 {
-   assert(val->num_components <= 3);
+   /* EXT_mesh_shader primitive indices: array of vectors.
+    * They don't count as per-primitive outputs, but the array is indexed
+    * by the primitive index, so they are practically per-primitive.
+    */
+   assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
+   assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
+
+   const unsigned component_offset = nir_intrinsic_component(intrin);
+   nir_def *store_val = intrin->src[0].ssa;
+   assert(store_val->num_components <= 3);
+
+   if (store_val->num_components > s->vertices_per_prim)
+      store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);
 
    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
-      for (unsigned c = 0; c < s->vertices_per_prim; ++c)
-         nir_store_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c], nir_channel(b, val, c), 0x1);
+      for (unsigned c = 0; c < store_val->num_components; ++c) {
+         const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
+         nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
+      }
 
       return;
    }
 
-   if (!offset_src)
-      offset_src = nir_imm_int(b, 0);
+   nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
+   nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
 
-   nir_store_shared(b, nir_u2u8(b, val), offset_src, .base = s->layout.lds.indices_addr);
+   /* The max vertex count is 256, so these indices always fit 8 bits.
+    * To reduce LDS use, store these as a flat array of 8-bit values.
+    */
+   nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset);
 }
 
 static void
@@ -3822,21 +3838,7 @@ ms_store_arrayed_output_intrin(nir_builder *b,
    unsigned location = nir_intrinsic_io_semantics(intrin).location;
 
    if (location == VARYING_SLOT_PRIMITIVE_INDICES) {
-      /* EXT_mesh_shader primitive indices: array of vectors.
-       * They don't count as per-primitive outputs, but the array is indexed
-       * by the primitive index, so they are practically per-primitive.
-       *
-       * The max vertex count is 256, so these indices always fit 8 bits.
-       * To reduce LDS use, store these as a flat array of 8-bit values.
-       */
-      assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
-      assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
-      assert(nir_intrinsic_component(intrin) == 0);
-
-      nir_def *store_val = intrin->src[0].ssa;
-      nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
-      nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
-      ms_store_prim_indices(b, store_val, offset, s);
+      ms_store_prim_indices(b, intrin, s);
       return;
    } else if (location == VARYING_SLOT_CULL_PRIMITIVE) {
       /* EXT_mesh_shader cull primitive: per-primitive bool.