agx/nir_lower_gs: clean up state/info duplication

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Mary Guillemard <mary.guillemard@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34638>
This commit is contained in:
Alyssa Rosenzweig 2025-04-21 13:04:59 -04:00 committed by Marge Bot
parent 753e3ba55b
commit 1017095c5a

View file

@ -28,23 +28,14 @@ struct lower_gs_state {
int static_count[MAX_VERTEX_STREAMS];
nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
/* Per-input primitive stride of the output index buffer */
unsigned max_indices;
/* The count buffer contains `count_stride_el` 32-bit words in a row for each
* input primitive, for `input_primitives * count_stride_el * 4` total bytes.
*/
unsigned count_stride_el;
/* The index of each counter in the count buffer, or -1 if it's not in the
* count buffer.
*
* Invariant: count_stride_el == sum(count_index[i][j] >= 0).
* Invariant: info->count_words == sum(count_index[i] >= 0).
*/
int count_index[MAX_VERTEX_STREAMS];
bool rasterizer_discard;
bool prefix_summing;
struct agx_gs_info *info;
};
@ -294,7 +285,7 @@ load_xfb_count_address(nir_builder *b, struct lower_gs_state *state,
return NULL;
nir_def *prim_offset_el =
nir_imul_imm(b, unrolled_id, state->count_stride_el);
nir_imul_imm(b, unrolled_id, state->info->count_words);
nir_def *offset_el = nir_iadd_imm(b, prim_offset_el, index);
@ -308,14 +299,14 @@ write_xfb_counts(nir_builder *b, nir_intrinsic_instr *intr,
{
/* Store each required counter */
nir_def *id =
state->prefix_summing ? calc_unrolled_id(b) : nir_imm_int(b, 0);
state->info->prefix_sum ? calc_unrolled_id(b) : nir_imm_int(b, 0);
nir_def *addr =
load_xfb_count_address(b, state, id, nir_intrinsic_stream_id(intr));
if (!addr)
return;
if (state->prefix_summing) {
if (state->info->prefix_sum) {
nir_store_global(b, addr, 4, intr->src[2].ssa, nir_component_mask(1));
} else {
nir_global_atomic(b, 32, addr, intr->src[2].ssa,
@ -629,8 +620,9 @@ agx_nir_create_gs_rast_shader(const nir_shader *gs, bool *side_effects_for_rast,
} else {
/* vertex ID = unrolled (see calc_unrolled_index_id), no instancing */
nir_def *raw_id = nir_load_vertex_id(b);
unsigned stride = state->info->indexed ? output_vertex_id_pot_stride(gs)
: MAX2(state->max_indices, 1);
unsigned stride = state->info->indexed
? output_vertex_id_pot_stride(gs)
: MAX2(state->info->max_indices, 1);
output_id = nir_umod_imm(b, raw_id, stride);
unrolled = nir_udiv_imm(b, raw_id, stride);
@ -738,7 +730,7 @@ previous_xfb_primitives(nir_builder *b, struct lower_gs_state *state,
* we can calculate the base.
*/
return nir_imul_imm(b, unrolled_id, static_count);
} else if (state->prefix_summing) {
} else if (state->info->prefix_sum) {
/* If we prefix summed, load from the sum buffer. Note that the sums are
* inclusive, so index 0 is nonzero. This requires a little fixup here. We
* use a saturating unsigned subtraction so we don't read out-of-bounds.
@ -776,7 +768,7 @@ lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
libagx_end_primitive(
b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
intr->src[1].ssa, intr->src[2].ssa,
nir_imul_imm(b, calc_unrolled_id(b), state->max_indices),
nir_imul_imm(b, calc_unrolled_id(b), state->info->max_indices),
calc_unrolled_index_id(b),
nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
}
@ -934,7 +926,7 @@ lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
libagx_pad_index_gs(b, load_geometry_param(b, output_index_buffer),
intr->src[0].ssa, intr->src[1].ssa,
calc_unrolled_id(b),
nir_imm_int(b, state_->max_indices));
nir_imm_int(b, state_->info->max_indices));
}
break;
@ -1282,27 +1274,27 @@ evaluate_topology(nir_builder *b, nir_intrinsic_instr *intr, void *data)
* 0, 1, 2, -1, 3, 4, 5, -1, ...
*/
static bool
match_list_topology(struct lower_gs_state *state, uint32_t count)
match_list_topology(struct agx_gs_info *info, uint32_t count)
{
unsigned count_with_restart = count + 1;
/* Must be an integer number of primitives */
if (state->max_indices % count_with_restart)
if (info->max_indices % count_with_restart)
return false;
/* Must match the list topology */
for (unsigned i = 0; i < state->max_indices; ++i) {
for (unsigned i = 0; i < info->max_indices; ++i) {
bool restart = (i % count_with_restart) == count;
uint32_t expected = restart ? -1 : (i - (i / count_with_restart));
if (state->info->topology[i] != expected)
if (info->topology[i] != expected)
return false;
}
/* If we match, rewrite the topology and drop indexing */
state->info->indexed = false;
state->info->mode = u_decomposed_prim(state->info->mode);
state->max_indices = (state->max_indices / count_with_restart) * count;
info->indexed = false;
info->mode = u_decomposed_prim(info->mode);
info->max_indices = (info->max_indices / count_with_restart) * count;
return true;
}
@ -1349,15 +1341,15 @@ optimize_static_topology(struct lower_gs_state *state, nir_shader *gs)
/* Try to pattern match a list topology */
unsigned count = verts_in_output_prim(gs);
if (match_list_topology(state, count))
if (match_list_topology(state->info, count))
return;
/* Because we're instancing, we can always drop the trailing restart index */
state->info->instanced = true;
state->max_indices--;
state->info->max_indices--;
/* Try to pattern match a strip topology */
if (is_strip_topology(state->info->topology, state->max_indices)) {
if (is_strip_topology(state->info->topology, state->info->max_indices)) {
state->info->indexed = false;
return;
}
@ -1456,16 +1448,15 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
*/
for (unsigned i = 0; i < MAX_VERTEX_STREAMS; ++i) {
gs_state.count_index[i] =
(gs_state.static_count[i] < 0) ? gs_state.count_stride_el++ : -1;
(gs_state.static_count[i] < 0) ? info->count_words++ : -1;
}
/* Using the gathered static counts, choose the index buffer stride. */
gs_state.max_indices = calculate_max_indices(
info->max_indices = calculate_max_indices(
gs->info.gs.output_primitive, gs->info.gs.vertices_out,
static_vertices[0], static_primitives[0]);
gs_state.prefix_summing =
gs_state.count_stride_el > 0 && gs->xfb_info != NULL;
info->prefix_sum = info->count_words > 0 && gs->xfb_info != NULL;
if (static_vertices >= 0 && static_primitives >= 0) {
optimize_static_topology(&gs_state, gs);
@ -1487,7 +1478,7 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
NIR_PASS(_, gs, nir_remove_dead_variables, nir_var_function_temp, NULL);
/* If there is any unknown count, we need a geometry count shader */
if (gs_state.count_stride_el > 0)
if (info->count_words > 0)
*gs_count = agx_nir_create_geometry_count_shader(gs, &gs_state);
else
*gs_count = NULL;
@ -1574,9 +1565,6 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
&gs_state, gs->xfb_info, verts_in_output_prim(gs),
gs->info.gs.active_stream_mask, gs->info.gs.invocations);
info->count_words = gs_state.count_stride_el;
info->prefix_sum = gs_state.prefix_summing;
info->max_indices = gs_state.max_indices;
return true;
}