agx/nir_lower_gs: optimize static topologies

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Mary Guillemard <mary.guillemard@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34638>
Alyssa Rosenzweig 2025-04-21 09:59:06 -04:00 committed by Marge Bot
parent 3da8939197
commit b9b6828fda
6 changed files with 299 additions and 58 deletions


@@ -45,6 +45,8 @@ struct lower_gs_state {
bool rasterizer_discard;
bool prefix_summing;
struct agx_gs_info *info;
};
/* Helpers for loading from the geometry state buffer */
@@ -256,10 +258,9 @@ calc_unrolled_id(nir_builder *b)
}
static unsigned
output_vertex_id_stride(nir_shader *gs)
output_vertex_id_pot_stride(const nir_shader *gs)
{
/* round up to power of two for cheap multiply/division */
return util_next_power_of_two(MAX2(gs->info.gs.vertices_out, 1));
return util_next_power_of_two(gs->info.gs.vertices_out);
}
/* Variant of calc_unrolled_id that uses a power-of-two stride for indices. This
@@ -274,7 +275,8 @@ output_vertex_id_stride(nir_shader *gs)
static nir_def *
calc_unrolled_index_id(nir_builder *b)
{
unsigned vertex_stride = output_vertex_id_stride(b->shader);
/* We know this is a dynamic topology and hence indexed */
unsigned vertex_stride = output_vertex_id_pot_stride(b->shader);
nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
nir_def *instance = nir_ishl(b, load_instance_id(b), primitives_log2);
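
Editor's note: because the per-output-vertex stride is rounded up to a power of two, the unrolled ID splits apart with a shift and a mask rather than a division. A minimal standalone sketch (not part of the commit; values illustrative, plain C rather than the NIR builder calls):

#include <assert.h>
#include <stdint.h>

int main(void)
{
   uint32_t stride = 8;            /* util_next_power_of_two(6 vertices_out) */
   uint32_t prim = 3, vtx = 4;     /* some primitive, some output vertex */
   uint32_t unrolled = prim * stride + vtx;

   /* Power-of-two stride: recover the pieces with shift/mask, no divide. */
   assert((unrolled >> 3) == prim);           /* log2(stride) == 3 */
   assert((unrolled & (stride - 1)) == vtx);
   return 0;
}
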
@@ -396,6 +398,7 @@ agx_nir_create_geometry_count_shader(nir_shader *gs,
}
struct lower_gs_rast_state {
nir_def *raw_instance_id;
nir_def *instance_id, *primitive_id, *output_id;
struct agx_lower_output_to_var_state outputs;
struct agx_lower_output_to_var_state selected;
@@ -444,6 +447,10 @@ lower_to_gs_rast(nir_builder *b, nir_intrinsic_instr *intr, void *data)
return true;
case nir_intrinsic_load_instance_id:
/* Don't lower recursively */
if (state->raw_instance_id == &intr->def)
return false;
nir_def_rewrite_uses(&intr->def, state->instance_id);
return true;
@@ -584,13 +591,12 @@ strip_side_effect_from_main(nir_builder *b, nir_intrinsic_instr *intr,
* shades each rasterized output vertex in parallel.
*/
static nir_shader *
agx_nir_create_gs_rast_shader(const nir_shader *gs, bool *side_effects_for_rast)
agx_nir_create_gs_rast_shader(const nir_shader *gs, bool *side_effects_for_rast,
const struct lower_gs_state *state)
{
/* Don't muck up the original shader */
nir_shader *shader = nir_shader_clone(NULL, gs);
unsigned max_verts = output_vertex_id_stride(shader);
/* Turn into a vertex shader run only for rasterization. Transform feedback
* was handled in the prepass.
*/
@@ -615,18 +621,40 @@ agx_nir_create_gs_rast_shader(const nir_shader *gs, bool *side_effects_for_rast)
if (shader->info.gs.output_primitive != MESA_PRIM_POINTS)
shader->info.outputs_written &= ~VARYING_BIT_PSIZ;
/* See calc_unrolled_index_id */
nir_def *output_id, *unrolled;
if (state->info->instanced) {
/* vertex ID = ID within the primitive, instance ID = unrolled prim ID */
output_id = nir_load_vertex_id(b);
unrolled = nir_load_instance_id(b);
} else {
/* vertex ID = unrolled (see calc_unrolled_index_id), no instancing */
nir_def *raw_id = nir_load_vertex_id(b);
nir_def *output_id = nir_umod_imm(b, raw_id, max_verts);
nir_def *unrolled = nir_udiv_imm(b, raw_id, max_verts);
unsigned stride = state->info->indexed ? output_vertex_id_pot_stride(gs)
: MAX2(state->max_indices, 1);
output_id = nir_umod_imm(b, raw_id, stride);
unrolled = nir_udiv_imm(b, raw_id, stride);
}
/* If we are indexed, we know indices are sparse and rounded up to powers of
* two, so we can just shift & mask to pick apart. Otherwise, we fall back on
* a slower integer division.
*/
nir_def *instance_id, *primitive_id;
if (state->info->indexed) {
nir_def *primitives_log2 = load_geometry_param(b, primitives_log2);
nir_def *instance_id = nir_ushr(b, unrolled, primitives_log2);
nir_def *primitive_id = nir_iand(
instance_id = nir_ushr(b, unrolled, primitives_log2);
primitive_id = nir_iand(
b, unrolled,
nir_iadd_imm(b, nir_ishl(b, nir_imm_int(b, 1), primitives_log2), -1));
} else {
nir_def *primitives = load_geometry_param(b, gs_grid[0]);
instance_id = nir_udiv(b, unrolled, primitives);
primitive_id = nir_umod(b, unrolled, primitives);
}
struct lower_gs_rast_state rast_state = {
.raw_instance_id = unrolled,
.instance_id = instance_id,
.primitive_id = primitive_id,
.output_id = output_id,
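
Editor's note: the two decompositions above, written out as scalar C for reference (a standalone sketch with made-up numbers, not part of the commit). Indexed topologies pad the primitive count to a power of two, so shift/mask suffices; otherwise the shader divides by the real primitive count loaded from gs_grid[0]:

#include <stdio.h>

static void decompose_indexed(unsigned unrolled, unsigned primitives_log2)
{
   unsigned instance_id = unrolled >> primitives_log2;
   unsigned primitive_id = unrolled & ((1u << primitives_log2) - 1);
   printf("indexed: instance %u, primitive %u\n", instance_id, primitive_id);
}

static void decompose_linear(unsigned unrolled, unsigned primitives)
{
   unsigned instance_id = unrolled / primitives;   /* slower integer divide */
   unsigned primitive_id = unrolled % primitives;
   printf("linear:  instance %u, primitive %u\n", instance_id, primitive_id);
}

int main(void)
{
   decompose_indexed(35, 4);  /* padded to 16 prims: instance 2, primitive 3 */
   decompose_linear(23, 10);  /* 10 real prims:      instance 2, primitive 3 */
   return 0;
}
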
@@ -891,13 +919,16 @@ static bool
lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
{
b->cursor = nir_before_instr(&intr->instr);
struct lower_gs_state *state_ = state;
switch (intr->intrinsic) {
case nir_intrinsic_set_vertex_and_primitive_count: {
if (!state_->info->dynamic_topology)
break;
/* Points write their index buffer here, other primitives write on end. We
* also pad the index buffer here for the rasterization stream.
*/
struct lower_gs_state *state_ = state;
if (b->shader->info.gs.output_primitive == MESA_PRIM_POINTS) {
lower_end_primitive(b, intr, state);
}
@@ -913,6 +944,10 @@ lower_gs_instr(nir_builder *b, nir_intrinsic_instr *intr, void *state)
}
case nir_intrinsic_end_primitive_with_counter: {
/* If the topology is static, we use the static index buffer instead. */
if (!state_->info->dynamic_topology)
break;
unsigned min = verts_in_output_prim(b->shader);
/* We only write out complete primitives */
@@ -1196,6 +1231,141 @@ calculate_max_indices(enum mesa_prim prim, unsigned verts, signed static_verts,
return verts + (verts / mesa_vertices_per_prim(prim));
}
static bool
evaluate_topology(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
bool points = b->shader->info.gs.output_primitive == MESA_PRIM_POINTS;
bool end_prim = intr->intrinsic == nir_intrinsic_end_primitive_with_counter;
bool set_prim =
intr->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
struct lower_gs_state *ctx = data;
if (!(set_prim && points) && !end_prim)
return false;
assert(!(end_prim && points) && "should have been deleted");
/* Only consider the rasterization stream. */
if (nir_intrinsic_stream_id(intr) != 0)
return false;
/* All end primitives must be executed exactly once. That happens if
* everything is in the start block.
*
* Strictly we could relax this (to handle if-statements interleaved with
* other stuff).
*/
if (intr->instr.block != nir_start_block(b->impl)) {
ctx->info->dynamic_topology = true;
return false;
}
/* The topology must be static */
if (!nir_src_is_const(intr->src[0]) || !nir_src_is_const(intr->src[1]) ||
!nir_src_is_const(intr->src[2])) {
ctx->info->dynamic_topology = true;
return false;
}
unsigned min = verts_in_output_prim(b->shader);
if (nir_src_as_uint(intr->src[1]) >= min) {
_libagx_end_primitive(ctx->info->topology, nir_src_as_uint(intr->src[0]),
nir_src_as_uint(intr->src[1]),
nir_src_as_uint(intr->src[2]), 0, 0, !points);
}
return false;
}
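
Editor's note, for intuition only: when every end-of-primitive sits in the start block with constant sources, the pass replays _libagx_end_primitive at compile time into info->topology. For a hypothetical GS that emits four vertices as one strip on stream 0, the baked index buffer would plausibly be the per-primitive output-vertex IDs followed by a restart, something like the snippet below; the exact encoding is whatever _libagx_end_primitive writes.

#include <stdint.h>

/* Hypothetical baked topology for a single 4-vertex strip (illustrative). */
static const uint32_t example_topology[5] = {0, 1, 2, 3, UINT32_MAX};
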
/*
* Pattern match the index buffer with restart against a list topology:
*
* 0, 1, 2, -1, 3, 4, 5, -1, ...
*/
static bool
match_list_topology(struct lower_gs_state *state, uint32_t count)
{
unsigned count_with_restart = count + 1;
/* Must be an integer number of primitives */
if (state->max_indices % count_with_restart)
return false;
/* Must match the list topology */
for (unsigned i = 0; i < state->max_indices; ++i) {
bool restart = (i % count_with_restart) == count;
uint32_t expected = restart ? -1 : (i - (i / count_with_restart));
if (state->info->topology[i] != expected)
return false;
}
/* If we match, rewrite the topology and drop indexing */
state->info->indexed = false;
state->info->mode = u_decomposed_prim(state->info->mode);
state->max_indices = (state->max_indices / count_with_restart) * count;
return true;
}
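
Editor's note: a quick worked example mirroring the check above (hypothetical, standalone). A GS that emits two independent triangles bakes {0, 1, 2, -1, 3, 4, 5, -1}; with count = 3 the restart-separated pattern matches, so indexing is dropped and the mode decomposes to a plain triangle list:

#include <stdbool.h>
#include <stdio.h>

int main(void)
{
   const unsigned topology[] = {0, 1, 2, ~0u, 3, 4, 5, ~0u};
   const unsigned max_indices = 8, count = 3, count_with_restart = count + 1;
   bool matches = (max_indices % count_with_restart) == 0;

   for (unsigned i = 0; matches && i < max_indices; ++i) {
      bool restart = (i % count_with_restart) == count;
      unsigned expected = restart ? ~0u : i - (i / count_with_restart);
      matches &= (topology[i] == expected);
   }

   /* Prints "match 1, 6 list indices": the draw becomes a non-indexed
    * triangle list of 8 / 4 * 3 = 6 vertices. */
   printf("match %d, %u list indices\n", matches,
          (max_indices / count_with_restart) * count);
   return 0;
}
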
static bool
is_strip_topology(uint32_t *indices, uint32_t index_count)
{
for (unsigned i = 0; i < index_count; ++i) {
if (indices[i] != i)
return false;
}
return true;
}
/*
* To handle the general case of geometry shaders generating dynamic topologies,
* we translate geometry shaders into compute shaders that write an index
* buffer. In practice, many geometry shaders have static topologies that can be
* determined at compile-time. By identifying these, we can avoid the dynamic
* index buffer allocation and writes. optimize_static_topology tries to
* statically determine the topology and, if it succeeds, translates it to one of:
*
* 1. Non-indexed line/triangle lists without instancing.
* 2. Non-indexed line/triangle strips, instanced per input primitive.
* 3. Static index buffer, instanced per input primitive.
*
* If the geometry shader has no side effect, the only job of the compute shader
* is writing this index buffer, so this optimization effectively eliminates the
* compute dispatch entirely. That means simple VS+GS pipelines turn into simple
* VS(compute) + GS(vertex) sequences without auxiliary programs.
*/
static void
optimize_static_topology(struct lower_gs_state *state, nir_shader *gs)
{
nir_shader_intrinsics_pass(gs, evaluate_topology, nir_metadata_all, state);
if (state->info->dynamic_topology)
return;
/* Points are always lists, we never have restarts/instancing */
if (gs->info.gs.output_primitive == MESA_PRIM_POINTS) {
state->info->indexed = false;
return;
}
/* Try to pattern match a list topology */
unsigned count = verts_in_output_prim(gs);
if (match_list_topology(state, count))
return;
/* Because we're instancing, we can always drop the trailing restart index */
state->info->instanced = true;
state->max_indices--;
/* Try to pattern match a strip topology */
if (is_strip_topology(state->info->topology, state->max_indices)) {
state->info->indexed = false;
return;
}
}
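
Editor's note: continuing the hypothetical single-strip example from above (topology {0, 1, 2, 3, -1}, max_indices = 5, output primitive TRIANGLE_STRIP), a standalone sketch of the decisions this function makes:

#include <stdbool.h>
#include <stdio.h>

int main(void)
{
   unsigned topology[] = {0, 1, 2, 3, ~0u};
   unsigned max_indices = 5;

   /* Not a list: 5 is not a multiple of 3 vertices + 1 restart. */
   bool list = (max_indices % 4) == 0;

   /* So instance per input primitive and drop the trailing restart... */
   bool instanced = !list;
   if (instanced)
      max_indices--;

   /* ...which leaves the identity sequence, i.e. a plain strip. */
   bool strip = true;
   for (unsigned i = 0; i < max_indices; ++i)
      strip &= (topology[i] == i);

   /* Prints "instanced 1, indexed 0, 4 vertices per instance": a non-indexed
    * 4-vertex strip drawn once per unrolled input primitive. */
   printf("instanced %d, indexed %d, %u vertices per instance\n", instanced,
          !strip, max_indices);
   return 0;
}
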
bool
agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
nir_shader **gs_copy, nir_shader **pre_gs,
@@ -1271,6 +1441,13 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
*/
struct lower_gs_state gs_state = {
.rasterizer_discard = rasterizer_discard,
.info = info,
};
*info = (struct agx_gs_info){
.mode = gs->info.gs.output_primitive,
.xfb = gs->xfb_info != NULL,
.indexed = true,
};
int static_vertices[4] = {0}, static_primitives[4] = {0};
@@ -1293,8 +1470,15 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
gs_state.prefix_summing =
gs_state.count_stride_el > 0 && gs->xfb_info != NULL;
if (static_vertices[0] >= 0 && static_primitives[0] >= 0) {
optimize_static_topology(&gs_state, gs);
} else {
info->dynamic_topology = true;
}
bool side_effects_for_rast = false;
*gs_copy = agx_nir_create_gs_rast_shader(gs, &side_effects_for_rast);
*gs_copy =
agx_nir_create_gs_rast_shader(gs, &side_effects_for_rast, &gs_state);
NIR_PASS(_, gs, nir_shader_intrinsics_pass, lower_id,
nir_metadata_control_flow, NULL);
@@ -1393,15 +1577,9 @@ agx_nir_lower_gs(nir_shader *gs, bool rasterizer_discard, nir_shader **gs_count,
&gs_state, gs->xfb_info, verts_in_output_prim(gs),
gs->info.gs.active_stream_mask, gs->info.gs.invocations);
/* Signal what primitive we want to draw the GS Copy VS with */
*info = (struct agx_gs_info){
.mode = gs->info.gs.output_primitive,
.count_words = gs_state.count_stride_el,
.prefix_sum = gs_state.prefix_summing,
.max_indices = gs_state.max_indices,
.xfb = gs->xfb_info != NULL,
};
info->count_words = gs_state.count_stride_el;
info->prefix_sum = gs_state.prefix_summing;
info->max_indices = gs_state.max_indices;
return true;
}
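
Editor's note: for the same hypothetical strip GS, roughly what the pass now reports to the driver. This is a local mock with only the fields discussed here and illustrative values; the real struct is agx_gs_info in the header changed below.

#include <stdbool.h>

struct example_gs_info {
   unsigned mode;              /* enum mesa_prim output primitive */
   unsigned max_indices;
   bool xfb, dynamic_topology, indexed, instanced;
};

static const struct example_gs_info example = {
   .mode = 5,                  /* enum mesa_prim value for TRIANGLE_STRIP */
   .max_indices = 4,           /* vertices per instanced strip */
   .dynamic_topology = false,  /* topology was evaluated at compile time */
   .indexed = false,           /* identity strip pattern, no index buffer */
   .instanced = true,          /* one instance per unrolled input primitive */
};
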


@@ -45,6 +45,18 @@ struct agx_gs_info {
/* Whether a prefix sum is required on the count outputs. Implies xfb */
bool prefix_sum;
/* Whether we need to dynamically allocate an index buffer. */
bool dynamic_topology;
/* Whether the topology requires an index buffer */
bool indexed;
/* Whether the topology requires hardware instancing */
bool instanced;
/* Static topology used if dynamic_topology is false. */
uint32_t topology[384];
};
bool agx_nir_lower_gs(struct nir_shader *gs, bool rasterizer_discard,


@@ -583,7 +583,8 @@ libagx_gs_setup_indirect(
uint32_t index_size_B /* 0 if no index buffer */,
uint32_t index_buffer_range_el,
uint32_t prim /* Input primitive type, enum mesa_prim */,
int is_prefix_summing, uint indices_per_in_prim)
int is_prefix_summing, uint indices_per_in_prim, int dynamic_topology,
int instanced)
{
/* Determine the (primitives, instances) grid size. */
uint vertex_count = draw[0];
@@ -636,17 +637,21 @@ libagx_gs_setup_indirect(
/* Allocate the index buffer and write the draw consuming it */
global VkDrawIndexedIndirectCommand *cmd = (global void *)p->indirect_desc;
uint count = p->input_primitives * indices_per_in_prim;
uint index_buffer_offset_B = agx_heap_alloc_nonatomic_offs(state, count * 4);
uint count = (instanced ? 1 : p->input_primitives) * indices_per_in_prim;
uint index_buffer_offset_B = 0;
*cmd = (VkDrawIndexedIndirectCommand){
.indexCount = count,
.instanceCount = 1,
.firstIndex = index_buffer_offset_B / 4,
};
if (dynamic_topology) {
index_buffer_offset_B = agx_heap_alloc_nonatomic_offs(state, count * 4);
p->output_index_buffer =
(global uint *)(state->heap + index_buffer_offset_B);
}
*cmd = (VkDrawIndexedIndirectCommand){
.indexCount = count,
.instanceCount = instanced ? p->input_primitives : 1,
.firstIndex = index_buffer_offset_B / 4,
};
}
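
Editor's note: to make the new count/allocation concrete, a standalone sketch with made-up numbers (say the count pass found 100 input primitives and indices_per_in_prim is 6):

#include <stdio.h>

int main(void)
{
   unsigned input_primitives = 100, indices_per_in_prim = 6;

   /* Instanced static topology: a tiny index count, primitives expressed as
    * hardware instances, and nothing allocated from the heap. */
   unsigned count_instanced = 1 * indices_per_in_prim;              /* 6 */
   unsigned instances = input_primitives;                           /* 100 */

   /* Dynamic topology: one flat index stream, heap-allocated at 4 B/index. */
   unsigned count_dynamic = input_primitives * indices_per_in_prim; /* 600 */
   unsigned heap_bytes = count_dynamic * 4;                         /* 2400 */

   printf("instanced: %u indices x %u instances; dynamic: %u indices, %u B\n",
          count_instanced, instances, count_dynamic, heap_bytes);
   return 0;
}
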
/*


@@ -462,6 +462,7 @@ struct hk_cmd_buffer {
uint64_t geom_indirect;
uint64_t geom_index_buffer;
uint32_t geom_index_count;
uint32_t geom_instance_count;
/* Does the command buffer use the geometry heap? */
bool uses_heap;


@@ -1092,6 +1092,7 @@ hk_rast_prim(struct hk_cmd_buffer *cmd)
static uint64_t
hk_upload_geometry_params(struct hk_cmd_buffer *cmd, struct agx_draw draw)
{
struct hk_device *dev = hk_cmd_buffer_device(cmd);
struct hk_descriptor_state *desc = &cmd->state.gfx.descriptors;
struct vk_dynamic_graphics_state *dyn = &cmd->vk.dynamic_graphics_state;
struct hk_graphics_state *gfx = &cmd->state.gfx;
@@ -1156,6 +1157,8 @@ hk_upload_geometry_params(struct hk_cmd_buffer *cmd, struct agx_draw draw)
params.indirect_desc = cmd->geom_indirect;
params.vs_grid[2] = params.gs_grid[2] = 1;
cmd->geom_index_buffer = dev->heap->va->addr;
cmd->geom_index_count = dev->heap->size;
} else {
uint32_t verts = draw.b.count[0], instances = draw.b.count[1];
@@ -1170,14 +1173,28 @@ hk_upload_geometry_params(struct hk_cmd_buffer *cmd, struct agx_draw draw)
params.count_buffer = hk_pool_alloc(cmd, size, 4).gpu;
}
if (count->info.gs.instanced) {
cmd->geom_index_count = count->info.gs.max_indices;
cmd->geom_instance_count = params.input_primitives;
} else {
cmd->geom_index_count =
params.input_primitives * count->info.gs.max_indices;
cmd->geom_instance_count = 1;
}
if (count->info.gs.dynamic_topology) {
params.output_index_buffer =
hk_pool_alloc(cmd, cmd->geom_index_count * 4, 4).gpu;
cmd->geom_index_buffer = params.output_index_buffer;
}
}
if (count->info.gs.indexed && !count->info.gs.dynamic_topology) {
cmd->geom_index_buffer = hk_pool_upload(
cmd, count->info.gs.topology, count->info.gs.max_indices * 4, 4);
}
desc->root_dirty = true;
return hk_pool_upload(cmd, &params, sizeof(params), 8);
@@ -1434,6 +1451,8 @@ hk_launch_gs_prerast(struct hk_cmd_buffer *cmd, struct hk_cs *cs,
.prim = mode,
.is_prefix_summing = count->info.gs.prefix_sum,
.indices_per_in_prim = count->info.gs.max_indices,
.dynamic_topology = count->info.gs.dynamic_topology,
.instanced = count->info.gs.instanced,
};
if (cmd->state.gfx.shaders[MESA_SHADER_TESS_EVAL]) {
@@ -1509,13 +1528,24 @@ hk_launch_gs_prerast(struct hk_cmd_buffer *cmd, struct hk_cs *cs,
hk_dispatch_with_local_size(cmd, cs, main, grid_gs, agx_workgroup(1, 1, 1));
if (agx_is_indirect(draw.b)) {
return agx_draw_indexed_indirect(cmd->geom_indirect, dev->heap->va->addr,
dev->heap->size, AGX_INDEX_SIZE_U32,
true);
if (count->info.gs.indexed) {
return agx_draw_indexed_indirect(
cmd->geom_indirect, cmd->geom_index_buffer, cmd->geom_index_count,
AGX_INDEX_SIZE_U32, true);
} else {
return agx_draw_indexed(cmd->geom_index_count, 1, 0, 0, 0,
return agx_draw_indirect(cmd->geom_indirect);
}
} else {
if (count->info.gs.indexed) {
return agx_draw_indexed(
cmd->geom_index_count, cmd->geom_instance_count, 0, 0, 0,
cmd->geom_index_buffer, cmd->geom_index_count * 4,
AGX_INDEX_SIZE_U32, true);
} else {
return (struct agx_draw){
.b = agx_3d(cmd->geom_index_count, cmd->geom_instance_count, 1),
};
}
}
}


@@ -4073,6 +4073,7 @@ agx_batch_geometry_params(struct agx_batch *batch, uint64_t input_index_buffer,
params.input_buffer = addr;
}
if (batch->ctx->gs->gs.dynamic_topology) {
unsigned idx_size =
params.input_primitives * batch->ctx->gs->gs.max_indices;
@@ -4082,6 +4083,7 @@ agx_batch_geometry_params(struct agx_batch *batch, uint64_t input_index_buffer,
.gpu;
batch->geom_index = params.output_index_buffer;
}
}
return agx_pool_upload_aligned_with_bo(&batch->pool, &params, sizeof(params),
8, &batch->geom_params_bo);
@@ -4155,6 +4157,8 @@ agx_launch_gs_prerast(struct agx_batch *batch,
.prim = info->mode,
.is_prefix_summing = gs->gs.prefix_sum,
.indices_per_in_prim = gs->gs.max_indices,
.instanced = gs->gs.instanced,
.dynamic_topology = gs->gs.dynamic_topology,
};
libagx_gs_setup_indirect_struct(batch, agx_1d(1), AGX_BARRIER_ALL, gsi);
@@ -5147,13 +5151,14 @@ agx_draw_vbo(struct pipe_context *pctx, const struct pipe_draw_info *info,
/* Setup to rasterize the GS results */
info_gs = (struct pipe_draw_info){
.mode = ctx->gs->gs.mode,
.index_size = 4,
.primitive_restart = true,
.index_size = ctx->gs->gs.indexed ? 4 : 0,
.primitive_restart = ctx->gs->gs.indexed,
.restart_index = ~0,
.index.resource = &index_rsrc.base,
.instance_count = 1,
};
unsigned unrolled_prims = 0;
if (indirect) {
indirect_gs = (struct pipe_draw_indirect_info){
.draw_count = 1,
@@ -5163,15 +5168,20 @@ agx_draw_vbo(struct pipe_context *pctx, const struct pipe_draw_info *info,
indirect = &indirect_gs;
} else {
unsigned unrolled_prims =
bool instanced = ctx->gs->gs.instanced;
unrolled_prims =
u_decomposed_prims_for_vertices(info->mode, draws->count) *
info->instance_count;
draw_gs = (struct pipe_draw_start_count_bias){
.count = ctx->gs->gs.max_indices * unrolled_prims,
.count = ctx->gs->gs.max_indices * (instanced ? 1 : unrolled_prims),
};
draws = &draw_gs;
if (ctx->gs->gs.instanced) {
info_gs.instance_count = unrolled_prims;
}
}
info = &info_gs;
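
Editor's note: a worked example for the direct path above (hypothetical numbers). Drawing 30 GL_TRIANGLES vertices with instance_count 2 gives unrolled_prims = 10 * 2 = 20; with max_indices = 6 the rasterization draw covers either 120 flat elements or 6 per instance across 20 hardware instances:

#include <stdbool.h>
#include <stdio.h>

int main(void)
{
   unsigned unrolled_prims = 20, max_indices = 6;
   bool instanced = true;                    /* ctx->gs->gs.instanced */

   unsigned count = max_indices * (instanced ? 1 : unrolled_prims);
   unsigned instance_count = instanced ? unrolled_prims : 1;

   /* Prints "6 elements x 20 instances"; with instanced = false it would be
    * "120 elements x 1 instances". */
   printf("%u elements x %u instances\n", count, instance_count);
   return 0;
}
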
@@ -5180,8 +5190,13 @@ agx_draw_vbo(struct pipe_context *pctx, const struct pipe_draw_info *info,
batch->reduced_prim = u_reduced_prim(info->mode);
ctx->dirty |= AGX_DIRTY_PRIM;
if (ctx->gs->gs.dynamic_topology) {
ib = batch->geom_index;
ib_extent = index_rsrc.bo->size - (batch->geom_index - ib);
} else if (ctx->gs->gs.indexed) {
ib_extent = ctx->gs->gs.max_indices * 4;
ib = agx_pool_upload(&batch->pool, ctx->gs->gs.topology, ib_extent);
}
/* We need to reemit geometry descriptors since the txf sampler may change
* between the GS prepass and the GS rast program.