ac/nir/ngg: add gs culling

Port from radeonsi.

Cull primitives after the GS threads have run and before the final
vertex/primitive export. GS culling works like VS/TES culling: it
reads the saved vertex positions of a primitive back from LDS, then
calls the primitive culling algorithm to check whether the primitive
is visible; only primitives that pass are exported.

Unlike VS/TES culling, which reads the vertex indices of a primitive
from VGPRs as shader args, the GS sets a primitive complete flag in
LDS for the last vertex of each primitive, so that the vertex threads
know the previous 1/2/3 vertices form a primitive and can run
primitive culling.
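
To illustrate the resulting flag layout (summarizing the comment in
the code below; not part of the change itself), the per-vertex
primflag byte stored in LDS encodes:

    bit 0: this vertex completes a primitive
    bit 1: the primitive index is odd (triangle strips only, else 0)
    bit 2: this vertex is live (always 1 without culling; with
           culling, set only after the culling pass)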

Acked-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17651>
Qiang Yu 2022-06-09 09:11:10 +08:00 committed by Marge Bot
parent b212fd4b1e
commit 1bdeb961bd
3 changed files with 154 additions and 9 deletions

@@ -140,7 +140,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
unsigned esgs_ring_lds_bytes,
unsigned gs_out_vtx_bytes,
unsigned gs_total_out_vtx_bytes,
bool provoking_vtx_last);
bool provoking_vtx_last,
bool can_cull);
void
ac_nir_lower_ngg_ms(nir_shader *shader,

@@ -100,6 +100,7 @@ typedef struct
bool found_out_vtxcnt[4];
bool output_compile_time_known;
bool provoking_vertex_last;
bool can_cull;
gs_output_info output_info[VARYING_SLOT_MAX];
} lower_ngg_gs_state;
@@ -1782,12 +1783,17 @@ lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intri
/* Calculate and store per-vertex primitive flags based on vertex counts:
* - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
* - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
* - bit 2: always 1 (so that we can use it for determining vertex liveness)
* - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
*/
nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
nir_ssa_def *vertex_live_flag = !stream && s->can_cull ?
nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2) :
nir_imm_int(b, 0b100);
nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
nir_ssa_def *complete_flag = nir_b2i32(b, completes_prim);
nir_ssa_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
if (s->num_vertices_per_primitive == 3) {
nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
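A concrete trace of the flag computation above (an illustrative
sketch, assuming a triangle strip with culling enabled at runtime,
i.e. nir_load_cull_any_enabled_amd returns true):

    /* Illustrative: third strip vertex, current_vtx_per_prim = 2:
     *   vertex_live_flag = (!cull_any_enabled) << 2  = 0b000
     *   completes_prim   = (2 >= 2)                  = true
     *   prim_flag        = 0b000 | 0b001             = 0b001
     *   odd              = 2 & 1                     = 0
     *   prim_flag       += odd << 1                  = 0b001
     * Bit 2 (live) is filled in later by the culling pass for
     * vertices that surviving primitives still reference.
     */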
@@ -1987,6 +1993,124 @@ ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_d
return nir_if_phi(b, primflag_0, zero);
}
static void
ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_ssa_def *last_vtxidx, nir_ssa_def *last_vtxptr,
nir_ssa_def *last_vtx_primflag, lower_ngg_gs_state *s,
nir_ssa_def *vtxptr[3])
{
unsigned last_vtx = s->num_vertices_per_primitive - 1;
vtxptr[last_vtx] = last_vtxptr;
bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
nir_ssa_def *is_odd = primitive_is_triangle ?
nir_ubfe(b, last_vtx_primflag, nir_imm_int(b, 1), nir_imm_int(b, 1)) : NULL;
for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
nir_ssa_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
/* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
* CW/CCW order for correct front/back face culling.
*/
if (primitive_is_triangle)
vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
}
}
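A worked example of the index math above (illustrative, not part of
the diff; assumes a triangle strip emitted as v0 v1 v2 v3, so the odd
triangle ends at last_vtxidx = 3 with is_odd = 1):

    /* Illustrative only:
     *   i = 0: vtxidx = 3 - 2 + is_odd = 2   -> v2
     *   i = 1: vtxidx = 3 - 1 - is_odd = 1   -> v1
     *   last : vtxidx = 3                    -> v3
     * Reading (v2, v1, v3) instead of (v1, v2, v3) undoes the winding
     * flip of odd strip triangles, so front/back face culling sees a
     * consistent CW/CCW order.
     */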
static nir_ssa_def *
ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_vtxcnt,
nir_ssa_def *out_vtx_lds_addr, nir_ssa_def *out_vtx_primflag_0,
lower_ngg_gs_state *s)
{
/* we haven't enabled point culling; if enabled, this function could be further optimized */
assert(s->num_vertices_per_primitive > 1);
/* save the primflag so that we don't need to load it from LDS again */
nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
/* the last bit of the primflag indicates if this is the final vertex of a primitive */
nir_ssa_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
nir_ssa_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
nir_ssa_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
nir_if *if_prim_enable = nir_push_if(b, prim_enable);
{
/* Calculate the LDS address of every vertex in the current primitive. */
nir_ssa_def *vtxptr[3];
ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
/* Load the positions from LDS. */
nir_ssa_def *pos[3][4];
for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
/* VARYING_SLOT_POS == 0, so base won't count packed location */
pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
pos[i][0] = nir_channel(b, xy, 0);
pos[i][1] = nir_channel(b, xy, 1);
pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
}
nir_ssa_def *accepted = ac_nir_cull_primitive(
b, nir_imm_bool(b, true), pos, s->num_vertices_per_primitive, NULL, NULL);
nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
{
/* clear the primflag if rejected */
nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
.base = s->lds_offs_primflags);
nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
}
nir_pop_if(b, if_rejected);
}
nir_pop_if(b, if_prim_enable);
/* Wait for the LDS primflag accesses to finish. */
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
.memory_scope = NIR_SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL,
.memory_modes = nir_var_mem_shared);
/* only dead vertices need a chance to be revived */
nir_ssa_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
nir_ssa_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
{
/* get succeeding vertices' primflag to detect this vertex's liveness */
for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
nir_ssa_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
nir_ssa_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
nir_if *if_not_overflow = nir_push_if(b, not_overflow);
{
nir_ssa_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
nir_ssa_def *vtx_primflag =
nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
vtx_primflag = nir_u2u32(b, vtx_primflag);
/* if the succeeding vertex is a live end-of-primitive vertex, we need
 * to set the current thread's vertex liveness flag (bit 2)
 */
nir_ssa_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
nir_ssa_def *vtx_live_flag =
nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
/* update this vertex's primflag */
nir_ssa_def *primflag = nir_load_var(b, primflag_var);
primflag = nir_ior(b, primflag, vtx_live_flag);
nir_store_var(b, primflag_var, primflag, 1);
}
nir_pop_if(b, if_not_overflow);
}
}
nir_pop_if(b, if_update_primflag);
return nir_load_var(b, primflag_var);
}
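A sketch of the liveness repair performed by the loop above
(illustrative, assuming triangles): a vertex whose primflag was
cleared may still be referenced by primitives ending at the next one
or two vertices, so its live bit is restored when any of them
survived:

    /* Illustrative only: thread T ends a culled primitive, so its
     * primflag was cleared to 0b000, but thread T+1 still ends a
     * surviving primitive (bit 0 set) that references vertex T:
     *
     *   thread:    T      T+1    T+2
     *   primflag:  0b000  0b001  ...
     *
     * The loop ORs 0b100 into T's primflag, so vertex T is still
     * treated as live and exported, while its own prim bit stays 0.
     */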
static void
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
{
@@ -2016,6 +2140,20 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
return;
}
/* cull primitives */
if (s->can_cull) {
nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
/* culling code will update the primflag */
nir_ssa_def *updated_primflag =
ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
out_vtx_primflag_0, s);
nir_pop_if(b, if_cull_en);
out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
}
/* When the output vertex count is not known at compile time:
* There may be gaps between invocations that have live vertices, but NGG hardware
* requires that the invocations that export vertices are packed (ie. compact).
@@ -2054,7 +2192,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
unsigned esgs_ring_lds_bytes,
unsigned gs_out_vtx_bytes,
unsigned gs_total_out_vtx_bytes,
bool provoking_vertex_last)
bool provoking_vertex_last,
bool can_cull)
{
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
@@ -2068,15 +2207,20 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
.lds_offs_primflags = gs_out_vtx_bytes,
.lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
.provoking_vertex_last = provoking_vertex_last,
.can_cull = can_cull,
};
unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
shader->info.shared_size = total_lds_bytes;
nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
state.const_out_prmcnt[0] != -1;
if (!can_cull) {
nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
state.const_out_prmcnt, 4u);
state.output_compile_time_known =
state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
state.const_out_prmcnt[0] != -1;
}
if (!state.output_compile_time_known)
state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
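One consequence worth spelling out (an editorial sketch, relying on
the designated initializer zero-initializing unlisted fields):

    /* With can_cull set, the vertex/primitive count scan is skipped:
     *   state.output_compile_time_known stays false
     *   -> ngg_gs_finale() always runs the vertex compaction path,
     *      repacking the gaps left by culled vertices. */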

@@ -1342,7 +1342,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
assert(info->is_ngg);
NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last);
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
} else if (nir->info.stage == MESA_SHADER_MESH) {
bool scratch_ring = false;
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
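
For comparison with the RADV call above, which keeps GS culling
disabled, a caller opting in would pass true for the new parameter.
A hypothetical sketch based on the signature added in this commit
(radeonsi, which this was ported from, enables it on its own side):

    /* Hypothetical opt-in; arguments follow the declaration above. */
    NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
               info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
               info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last,
               true /* can_cull */);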