nir/lower_gs_intrinsics: drop stuff added for AGX

AGX now vendors a significantly different version of this pass, so the common
one doesn't need the stuff added for AGX.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Mary Guillemard <mary.guillemard@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35802>
This commit is contained in:
Alyssa Rosenzweig 2025-06-27 13:35:16 -04:00 committed by Marge Bot
parent debd619555
commit d13b321201
3 changed files with 6 additions and 115 deletions

View file

@ -5845,8 +5845,6 @@ typedef enum {
nir_lower_gs_intrinsics_count_primitives = 1 << 1,
nir_lower_gs_intrinsics_count_vertices_per_primitive = 1 << 2,
nir_lower_gs_intrinsics_overwrite_incomplete = 1 << 3,
nir_lower_gs_intrinsics_always_end_primitive = 1 << 4,
nir_lower_gs_intrinsics_count_decomposed_primitives = 1 << 5,
} nir_lower_gs_intrinsics_flags;
bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);

View file

@ -616,16 +616,12 @@ intrinsic("end_primitive", indices=[STREAM_ID])
# Alternatively, drivers may implement these intrinsics, and use
# nir_lower_gs_intrinsics() to convert from the basic intrinsics.
#
# These contain four additional unsigned integer sources:
# These contain two additional unsigned integer sources:
# 1. The total number of vertices emitted so far.
# 2. The number of vertices emitted for the current primitive
# so far if we're counting, otherwise undef.
# 3. The total number of primitives emitted so far.
# 4. The total number of decomposed primitives emitted so far. This counts like
# the PRIMITIVES_GENERATED query: a triangle strip with 5 vertices is counted
# as 3 primitives (not 1).
intrinsic("emit_vertex_with_counter", src_comp=[1, 1, 1, 1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1, 1, 1, 1], indices=[STREAM_ID])
intrinsic("emit_vertex_with_counter", src_comp=[1, 1], indices=[STREAM_ID])
intrinsic("end_primitive_with_counter", src_comp=[1, 1], indices=[STREAM_ID])
# Contains the final total vertex, primitive, and decomposed primitive counts
# in the current GS thread.
intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1, 1], indices=[STREAM_ID])

View file

@ -59,11 +59,9 @@ struct state {
nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
nir_variable *vtxcnt_per_prim_vars[NIR_MAX_XFB_STREAMS];
nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS];
nir_variable *decomposed_primitive_count_vars[NIR_MAX_XFB_STREAMS];
bool per_stream;
bool count_prims;
bool count_vtx_per_prim;
bool count_decomposed_prims;
bool overwrite_incomplete;
bool is_points;
bool progress;
@ -89,8 +87,6 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
assert(state->vertex_count_vars[stream] != NULL);
nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
nir_def *count_per_primitive;
nir_def *primitive_count;
nir_def *decomposed_primitive_count;
if (state->count_vtx_per_prim)
count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
@ -99,18 +95,6 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
else
count_per_primitive = nir_undef(b, 1, 32);
if (state->count_prims)
primitive_count = nir_load_var(b, state->primitive_count_vars[stream]);
else
primitive_count = nir_undef(b, 1, 32);
if (state->count_decomposed_prims) {
decomposed_primitive_count =
nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
} else {
decomposed_primitive_count = nir_undef(b, 1, 32);
}
/* Create: if (vertex_count < max_vertices) and insert it.
*
* The new if statement needs to be hooked up to the control flow graph
@ -118,8 +102,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
*/
nir_push_if(b, nir_ilt_imm(b, count, b->shader->info.gs.vertices_out));
nir_emit_vertex_with_counter(b, count, count_per_primitive, primitive_count,
decomposed_primitive_count, stream);
nir_emit_vertex_with_counter(b, count, count_per_primitive, stream);
/* Increment the vertex count by 1 */
nir_store_var(b, state->vertex_count_vars[stream],
@ -135,26 +118,6 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
0x1); /* .x */
}
if (state->count_decomposed_prims) {
nir_variable *vtx_var = state->vtxcnt_per_prim_vars[stream];
nir_def *vtx_per_prim_cnt = state->is_points ? nir_imm_int(b, 1) : nir_load_var(b, vtx_var);
/* We form a new primitive for every vertex emitted after the first
* complete primitive (since we're outputting strips).
*/
unsigned min_verts = nir_verts_in_output_prim(b->shader);
nir_def *new_prim = nir_uge_imm(b, vtx_per_prim_cnt, min_verts);
/* Increment the decomposed primitive count by 1 if we formed a complete
* primitive.
*/
nir_variable *var = state->decomposed_primitive_count_vars[stream];
nir_def *cnt = nir_load_var(b, var);
nir_store_var(b, var,
nir_iadd(b, cnt, nir_b2i32(b, new_prim)),
0x1); /* .x */
}
nir_pop_if(b, NULL);
nir_instr_remove(&intrin->instr);
@ -239,29 +202,13 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
assert(state->vertex_count_vars[stream] != NULL);
nir_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
nir_def *count_per_primitive;
nir_def *primitive_count;
nir_def *decomposed_primitive_count;
if (state->count_vtx_per_prim)
count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
else
count_per_primitive = nir_undef(b, count->num_components, count->bit_size);
if (state->count_prims)
primitive_count = nir_load_var(b, state->primitive_count_vars[stream]);
else
primitive_count = nir_undef(b, 1, 32);
if (state->count_decomposed_prims) {
decomposed_primitive_count =
nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
} else {
decomposed_primitive_count = nir_undef(b, 1, 32);
}
nir_end_primitive_with_counter(b, count, count_per_primitive,
primitive_count,
decomposed_primitive_count, stream);
nir_end_primitive_with_counter(b, count, count_per_primitive, stream);
if (state->count_prims) {
/* Increment the primitive count by 1 */
@ -332,7 +279,6 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
nir_def *vtx_cnt;
nir_def *prim_cnt;
nir_def *decomposed_prim_cnt;
if (state->per_stream && !(shader->info.gs.active_stream_mask & (1 << stream))) {
/* Inactive stream: vertex count is 0, primitive count is 0 or undef. */
@ -340,7 +286,6 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
prim_cnt = state->count_prims || state->is_points
? nir_imm_int(b, 0)
: nir_undef(b, 1, 32);
decomposed_prim_cnt = prim_cnt;
} else {
if (state->overwrite_incomplete)
overwrite_incomplete_primitives(state, stream);
@ -356,48 +301,15 @@ append_set_vertex_and_primitive_count(nir_block *end_block, struct state *state)
prim_cnt = vtx_cnt;
else
prim_cnt = nir_undef(b, 1, 32);
if (state->count_decomposed_prims) {
decomposed_prim_cnt =
nir_load_var(b, state->decomposed_primitive_count_vars[stream]);
} else {
decomposed_prim_cnt = nir_undef(b, 1, 32);
}
}
nir_set_vertex_and_primitive_count(b, vtx_cnt, prim_cnt,
decomposed_prim_cnt, stream);
nir_undef(b, 1, 32), stream);
state->progress = true;
}
}
}
/*
 * Place an EndPrimitive intrinsic at every exit of the geometry shader, so a
 * backend only has to emit primitives when it encounters EndPrimitive. Should
 * the trailing EndPrimitive turn out to be unnecessary, it gets predicated out
 * via overwrite_incomplete_primitives.
 */
static void
append_end_primitive(nir_block *end_block, struct state *state)
{
   nir_builder *builder = state->builder;

   /* With no active stream there is no primitive to end, so emit nothing */
   if (builder->shader->info.gs.active_stream_mask != 0) {
      /* Walk every predecessor of the end block and add the intrinsic at its
       * tail, ahead of any jump instruction (return).
       */
      set_foreach(end_block->predecessors, entry) {
         nir_block *exit_block = (nir_block *)entry->key;

         builder->cursor = nir_after_block_before_jump(exit_block);
         nir_end_primitive(builder);
      }
   }
}
/**
* Check to see if there are any blocks that need set_vertex_and_primitive_count
*
@ -445,11 +357,9 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
bool per_stream = options & nir_lower_gs_intrinsics_per_stream;
bool count_primitives = options & nir_lower_gs_intrinsics_count_primitives;
bool overwrite_incomplete = options & nir_lower_gs_intrinsics_overwrite_incomplete;
bool always_end_primitive_non_points = options & nir_lower_gs_intrinsics_always_end_primitive;
bool count_vtx_per_prim =
overwrite_incomplete ||
(options & nir_lower_gs_intrinsics_count_vertices_per_primitive);
bool count_decomposed_prims = options & nir_lower_gs_intrinsics_count_decomposed_primitives;
bool is_points = shader->info.gs.output_primitive == MESA_PRIM_POINTS;
/* points are always complete primitives with a single vertex, so these are
@ -466,7 +376,6 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
state.progress = false;
state.count_prims = count_primitives;
state.count_vtx_per_prim = count_vtx_per_prim;
state.count_decomposed_prims = count_decomposed_prims;
state.overwrite_incomplete = overwrite_incomplete;
state.per_stream = per_stream;
state.is_points = is_points;
@ -502,13 +411,6 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
/* initialize to 0 */
nir_store_var(&b, state.vtxcnt_per_prim_vars[i], nir_imm_int(&b, 0), 0x1);
}
if (count_decomposed_prims) {
state.decomposed_primitive_count_vars[i] =
nir_local_variable_create(impl, glsl_uint_type(), "decomposed_primitive_count");
/* initialize to 0 */
nir_store_var(&b, state.decomposed_primitive_count_vars[i],
nir_imm_int(&b, 0), 0x1);
}
} else {
/* If per_stream is false, we only have one counter of each kind which we
* want to use for all streams. Duplicate the counter pointers so all
@ -520,14 +422,9 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
state.primitive_count_vars[i] = state.primitive_count_vars[0];
if (count_vtx_per_prim)
state.vtxcnt_per_prim_vars[i] = state.vtxcnt_per_prim_vars[0];
if (count_decomposed_prims)
state.decomposed_primitive_count_vars[i] = state.decomposed_primitive_count_vars[0];
}
}
if (always_end_primitive_non_points && !is_points)
append_end_primitive(impl->end_block, &state);
nir_foreach_block_safe(block, impl)
rewrite_intrinsics(block, &state);