libagx: do not use prefix sums for GS index buffer

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33901>
This commit is contained in:
Alyssa Rosenzweig 2025-01-27 15:10:42 -05:00 committed by Marge Bot
parent 4d2ab1d92c
commit afb53c82bc
4 changed files with 67 additions and 24 deletions

View file

@ -35,6 +35,9 @@ struct lower_gs_state {
int static_count[GS_NUM_COUNTERS][MAX_VERTEX_STREAMS];
nir_variable *outputs[NUM_TOTAL_VARYING_SLOTS][MAX_PRIM_OUT_SIZE];
/* Per-input primitive stride of the output index buffer */
unsigned max_indices;
/* The count buffer contains `count_stride_el` 32-bit words in a row for each
* input primitive, for `input_primitives * count_stride_el * 4` total bytes.
*/
@ -772,8 +775,7 @@ lower_end_primitive(nir_builder *b, nir_intrinsic_instr *intr,
libagx_end_primitive(
b, load_geometry_param(b, output_index_buffer), intr->src[0].ssa,
intr->src[1].ssa, intr->src[2].ssa,
previous_vertices(b, state, 0, calc_unrolled_id(b)),
previous_primitives(b, state, 0, calc_unrolled_id(b)),
nir_imul_imm(b, calc_unrolled_id(b), state->max_indices),
calc_unrolled_index_id(b),
nir_imm_bool(b, b->shader->info.gs.output_primitive != MESA_PRIM_POINTS));
}
@ -915,15 +917,24 @@ lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
b->cursor = nir_before_instr(&intr->instr);
switch (intr->intrinsic) {
case nir_intrinsic_set_vertex_and_primitive_count:
/* This instruction is mostly for the count shader, so just remove. But
* for points, we write the index buffer here so the rast shader can map.
case nir_intrinsic_set_vertex_and_primitive_count: {
/* Points write their index buffer here, other primitives write on end. We
* also pad the index buffer here for the rasterization stream.
*/
struct lower_gs_state *state_ = state;
if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
lower_end_primitive(b, intr, state);
}
if (nir_intrinsic_stream_id(intr) == 0 && !state_->rasterizer_discard) {
libagx_pad_index_gs(b, load_geometry_param(b, output_index_buffer),
intr->src[0].ssa, intr->src[1].ssa,
calc_unrolled_id(b),
nir_imm_int(b, state_->max_indices));
}
break;
}
case nir_intrinsic_end_primitive_with_counter: {
unsigned min = verts_in_output_prim(b->shader);
@ -978,7 +989,8 @@ collect_components(nir_builder *b, nir_intrinsic_instr *intr, void *data)
static nir_shader *
agx_nir_create_pre_gs(struct lower_gs_state *state, bool indexed, bool restart,
struct nir_xfb_info *xfb, unsigned vertices_per_prim,
uint8_t streams, unsigned invocations)
uint8_t streams, unsigned invocations,
unsigned index_buffer_allocation)
{
nir_builder b_ = nir_builder_init_simple_shader(
MESA_SHADER_COMPUTE, &agx_nir_options, "Pre-GS patch up");
@ -991,9 +1003,7 @@ agx_nir_create_pre_gs(struct lower_gs_state *state, bool indexed, bool restart,
if (!state->rasterizer_discard) {
libagx_build_gs_draw(
b, nir_load_geometry_param_buffer_agx(b),
previous_vertices(b, state, 0, unrolled_in_prims),
restart ? previous_primitives(b, state, 0, unrolled_in_prims)
: nir_imm_int(b, 0));
nir_imul_imm(b, unrolled_in_prims, index_buffer_allocation));
}
/* Determine the number of primitives generated in each stream */
@ -1197,6 +1207,29 @@ agx_nir_lower_gs_instancing(nir_shader *gs)
nir_metadata_control_flow, index);
}
static unsigned
calculate_max_indices(enum mesa_prim prim, unsigned verts, signed static_verts,
signed static_prims)
{
/* We always have a static max_vertices, but we might have a tighter bound.
* Use it if we have one
*/
if (static_verts >= 0) {
verts = MIN2(verts, static_verts);
}
/* Points do not need primitive count added. Other topologies do. If we have
* a static primitive count, we use that. Otherwise, we use a worst case
* estimate that primitives are emitted one-by-one.
*/
if (prim == MESA_PRIM_POINTS)
return verts;
else if (static_prims >= 0)
return verts + static_prims;
else
return verts + (verts / mesa_vertices_per_prim(prim));
}
bool
agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
nir_shader **gs_copy, nir_shader **pre_gs,
@ -1289,6 +1322,12 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
}
}
/* Using the gathered static counts, choose the index buffer stride. */
gs_state.max_indices = calculate_max_indices(
gs->info.gs.output_primitive, gs->info.gs.vertices_out,
gs_state.static_count[GS_COUNTER_VERTICES][0],
gs_state.static_count[GS_COUNTER_PRIMITIVES][0]);
bool side_effects_for_rast = false;
*gs_copy = agx_nir_create_gs_rast_shader(gs, &side_effects_for_rast);
@ -1388,7 +1427,7 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
*pre_gs = agx_nir_create_pre_gs(
&gs_state, true, gs->info.gs.output_primitive != MESA_PRIM_POINTS,
gs->xfb_info, verts_in_output_prim(gs), gs->info.gs.active_stream_mask,
gs->info.gs.invocations);
gs->info.gs.invocations, gs_state.max_indices);
/* Signal what primitive we want to draw the GS Copy VS with */
*out_mode = gs->info.gs.output_primitive;

View file

@ -566,8 +566,7 @@ libagx_setup_xfb_buffer(global struct agx_geometry_params *p, uint i)
*/
void
libagx_end_primitive(global int *index_buffer, uint total_verts,
uint verts_in_prim, uint total_prims,
uint invocation_vertex_base, uint invocation_prim_base,
uint verts_in_prim, uint total_prims, uint index_offs,
uint geometry_base, bool restart)
{
/* Previous verts/prims are from previous invocations plus earlier
@ -575,14 +574,15 @@ libagx_end_primitive(global int *index_buffer, uint total_verts,
* subtract the count for this prim from the inclusive sum NIR gives us.
*/
uint previous_verts_in_invoc = (total_verts - verts_in_prim);
uint previous_verts = invocation_vertex_base + previous_verts_in_invoc;
uint previous_prims = restart ? invocation_prim_base + (total_prims - 1) : 0;
uint previous_verts = previous_verts_in_invoc;
uint previous_prims = restart ? (total_prims - 1) : 0;
/* The indices are encoded as: (unrolled ID * output vertices) + vertex. */
uint index_base = geometry_base + previous_verts_in_invoc;
/* Index buffer contains 1 index for each vertex and 1 for each prim */
global int *out = &index_buffer[previous_verts + previous_prims];
global int *out =
&index_buffer[index_offs + previous_verts + previous_prims];
/* Write out indices for the strip */
for (uint i = 0; i < verts_in_prim; ++i) {
@ -594,15 +594,20 @@ libagx_end_primitive(global int *index_buffer, uint total_verts,
}
void
libagx_build_gs_draw(global struct agx_geometry_params *p, uint vertices,
uint primitives)
libagx_pad_index_gs(global int *index_buffer, uint total_verts,
uint total_prims, uint id, uint alloc)
{
for (uint i = total_verts + total_prims; i < alloc; ++i) {
index_buffer[(id * alloc) + i] = -1;
}
}
void
libagx_build_gs_draw(global struct agx_geometry_params *p, uint indices)
{
global uint *descriptor = p->indirect_desc;
global struct agx_geometry_state *state = p->state;
/* Setup the indirect draw descriptor */
uint indices = vertices + primitives; /* includes restart indices */
/* Allocate the index buffer */
uint index_buffer_offset_B = state->heap_bottom;
p->output_index_buffer =
@ -610,6 +615,7 @@ libagx_build_gs_draw(global struct agx_geometry_params *p, uint vertices,
state->heap_bottom += (indices * 4);
assert(state->heap_bottom < state->heap_size);
/* Setup the indirect draw descriptor */
descriptor[0] = indices; /* count */
descriptor[1] = 1; /* instance count */
descriptor[2] = index_buffer_offset_B / 4; /* start */

View file

@ -1487,10 +1487,8 @@ hk_launch_gs_prerast(struct hk_cmd_buffer *cmd, struct hk_cs *cs,
/* Pre-rast geometry shader */
hk_dispatch_with_local_size(cmd, cs, main, grid_gs, agx_workgroup(1, 1, 1));
bool restart = cmd->state.gfx.topology != AGX_PRIMITIVE_POINTS;
return agx_draw_indexed_indirect(cmd->geom_indirect, dev->heap->va->addr,
dev->heap->size, AGX_INDEX_SIZE_U32,
restart);
dev->heap->size, AGX_INDEX_SIZE_U32, true);
}
static struct agx_draw

View file

@ -5076,7 +5076,7 @@ agx_draw_vbo(struct pipe_context *pctx, const struct pipe_draw_info *info,
info_gs = (struct pipe_draw_info){
.mode = ctx->gs->gs_output_mode,
.index_size = 4,
.primitive_restart = ctx->gs->gs_output_mode != MESA_PRIM_POINTS,
.primitive_restart = true,
.restart_index = ~0,
.index.resource = ctx->heap,
.instance_count = 1,