mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-25 04:20:08 +01:00
ac/nir/ngg: add gs culling
Port from radeonsi. Cull primitive after GS thread and before final vertex/primitive export. GS culling is like VS/TES culling which read out saved vertex positions of a primitive from LDS then call the primitive culling algorithm to check whether it's visiable or not, only passed primitives will be exported. Unlike the VS/TES culling that read vertex index of a primitive from VGPRs as shader args, GS will set a primitive complete flag for each last vertex of a primitive in LDS, so that vertex thread know the previous 1/2/3 vertex can form a primitive and do primitive culling. Acked-by: Marek Olšák <marek.olsak@amd.com> Reviewed-by: Timur Kristóf <timur.kristof@gmail.com> Signed-off-by: Qiang Yu <yuq825@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17651>
This commit is contained in:
parent
b212fd4b1e
commit
1bdeb961bd
3 changed files with 154 additions and 9 deletions
|
|
@ -140,7 +140,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
|||
unsigned esgs_ring_lds_bytes,
|
||||
unsigned gs_out_vtx_bytes,
|
||||
unsigned gs_total_out_vtx_bytes,
|
||||
bool provoking_vtx_last);
|
||||
bool provoking_vtx_last,
|
||||
bool can_cull);
|
||||
|
||||
void
|
||||
ac_nir_lower_ngg_ms(nir_shader *shader,
|
||||
|
|
|
|||
|
|
@ -100,6 +100,7 @@ typedef struct
|
|||
bool found_out_vtxcnt[4];
|
||||
bool output_compile_time_known;
|
||||
bool provoking_vertex_last;
|
||||
bool can_cull;
|
||||
gs_output_info output_info[VARYING_SLOT_MAX];
|
||||
} lower_ngg_gs_state;
|
||||
|
||||
|
|
@ -1782,12 +1783,17 @@ lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intri
|
|||
/* Calculate and store per-vertex primitive flags based on vertex counts:
|
||||
* - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
|
||||
* - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
|
||||
* - bit 2: always 1 (so that we can use it for determining vertex liveness)
|
||||
* - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
|
||||
*/
|
||||
|
||||
nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
|
||||
nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
|
||||
nir_ssa_def *vertex_live_flag = !stream && s->can_cull ?
|
||||
nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2) :
|
||||
nir_imm_int(b, 0b100);
|
||||
|
||||
nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
|
||||
nir_ssa_def *complete_flag = nir_b2i32(b, completes_prim);
|
||||
|
||||
nir_ssa_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
|
||||
if (s->num_vertices_per_primitive == 3) {
|
||||
nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
|
||||
prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
|
||||
|
|
@ -1987,6 +1993,124 @@ ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_d
|
|||
return nir_if_phi(b, primflag_0, zero);
|
||||
}
|
||||
|
||||
static void
|
||||
ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_ssa_def *last_vtxidx, nir_ssa_def *last_vtxptr,
|
||||
nir_ssa_def *last_vtx_primflag, lower_ngg_gs_state *s,
|
||||
nir_ssa_def *vtxptr[3])
|
||||
{
|
||||
unsigned last_vtx = s->num_vertices_per_primitive - 1;
|
||||
vtxptr[last_vtx]= last_vtxptr;
|
||||
|
||||
bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
|
||||
nir_ssa_def *is_odd = primitive_is_triangle ?
|
||||
nir_ubfe(b, last_vtx_primflag, nir_imm_int(b, 1), nir_imm_int(b, 1)) : NULL;
|
||||
|
||||
for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
|
||||
nir_ssa_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
|
||||
|
||||
/* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
|
||||
* CW/CCW order for correct front/back face culling.
|
||||
*/
|
||||
if (primitive_is_triangle)
|
||||
vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
|
||||
|
||||
vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
|
||||
}
|
||||
}
|
||||
|
||||
static nir_ssa_def *
|
||||
ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_vtxcnt,
|
||||
nir_ssa_def *out_vtx_lds_addr, nir_ssa_def *out_vtx_primflag_0,
|
||||
lower_ngg_gs_state *s)
|
||||
{
|
||||
/* we haven't enabled point culling, if enabled this function could be further optimized */
|
||||
assert(s->num_vertices_per_primitive > 1);
|
||||
|
||||
/* save the primflag so that we don't need to load it from LDS again */
|
||||
nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
|
||||
nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
|
||||
|
||||
/* last bit of primflag indicate if this is the final vertex of a primitive */
|
||||
nir_ssa_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
|
||||
nir_ssa_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
|
||||
nir_ssa_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
|
||||
|
||||
nir_if *if_prim_enable = nir_push_if(b, prim_enable);
|
||||
{
|
||||
/* Calculate the LDS address of every vertex in the current primitive. */
|
||||
nir_ssa_def *vtxptr[3];
|
||||
ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
|
||||
|
||||
/* Load the positions from LDS. */
|
||||
nir_ssa_def *pos[3][4];
|
||||
for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
|
||||
/* VARYING_SLOT_POS == 0, so base won't count packed location */
|
||||
pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
|
||||
nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
|
||||
pos[i][0] = nir_channel(b, xy, 0);
|
||||
pos[i][1] = nir_channel(b, xy, 1);
|
||||
|
||||
pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
|
||||
pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
|
||||
}
|
||||
|
||||
nir_ssa_def *accepted = ac_nir_cull_primitive(
|
||||
b, nir_imm_bool(b, true), pos, s->num_vertices_per_primitive, NULL, NULL);
|
||||
|
||||
nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
|
||||
{
|
||||
/* clear the primflag if rejected */
|
||||
nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
|
||||
.base = s->lds_offs_primflags);
|
||||
|
||||
nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
|
||||
}
|
||||
nir_pop_if(b, if_rejected);
|
||||
}
|
||||
nir_pop_if(b, if_prim_enable);
|
||||
|
||||
/* Wait for LDS primflag access done. */
|
||||
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_scope = NIR_SCOPE_WORKGROUP,
|
||||
.memory_semantics = NIR_MEMORY_ACQ_REL,
|
||||
.memory_modes = nir_var_mem_shared);
|
||||
|
||||
/* only dead vertex need a chance to relive */
|
||||
nir_ssa_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
|
||||
nir_ssa_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
|
||||
nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
|
||||
{
|
||||
/* get succeeding vertices' primflag to detect this vertex's liveness */
|
||||
for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
|
||||
nir_ssa_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
|
||||
nir_ssa_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
|
||||
nir_if *if_not_overflow = nir_push_if(b, not_overflow);
|
||||
{
|
||||
nir_ssa_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
|
||||
nir_ssa_def *vtx_primflag =
|
||||
nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
|
||||
vtx_primflag = nir_u2u32(b, vtx_primflag);
|
||||
|
||||
/* if succeeding vertex is alive end of primitive vertex, need to set current
|
||||
* thread vertex's liveness flag (bit 2)
|
||||
*/
|
||||
nir_ssa_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
|
||||
nir_ssa_def *vtx_live_flag =
|
||||
nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
|
||||
|
||||
/* update this vertex's primflag */
|
||||
nir_ssa_def *primflag = nir_load_var(b, primflag_var);
|
||||
primflag = nir_ior(b, primflag, vtx_live_flag);
|
||||
nir_store_var(b, primflag_var, primflag, 1);
|
||||
}
|
||||
nir_pop_if(b, if_not_overflow);
|
||||
}
|
||||
}
|
||||
nir_pop_if(b, if_update_primflag);
|
||||
|
||||
return nir_load_var(b, primflag_var);
|
||||
}
|
||||
|
||||
static void
|
||||
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
|
||||
{
|
||||
|
|
@ -2016,6 +2140,20 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
|
|||
return;
|
||||
}
|
||||
|
||||
/* cull primitives */
|
||||
if (s->can_cull) {
|
||||
nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
|
||||
|
||||
/* culling code will update the primflag */
|
||||
nir_ssa_def *updated_primflag =
|
||||
ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
|
||||
out_vtx_primflag_0, s);
|
||||
|
||||
nir_pop_if(b, if_cull_en);
|
||||
|
||||
out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
|
||||
}
|
||||
|
||||
/* When the output vertex count is not known at compile time:
|
||||
* There may be gaps between invocations that have live vertices, but NGG hardware
|
||||
* requires that the invocations that export vertices are packed (ie. compact).
|
||||
|
|
@ -2054,7 +2192,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
|||
unsigned esgs_ring_lds_bytes,
|
||||
unsigned gs_out_vtx_bytes,
|
||||
unsigned gs_total_out_vtx_bytes,
|
||||
bool provoking_vertex_last)
|
||||
bool provoking_vertex_last,
|
||||
bool can_cull)
|
||||
{
|
||||
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
|
||||
assert(impl);
|
||||
|
|
@ -2068,15 +2207,20 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
|
|||
.lds_offs_primflags = gs_out_vtx_bytes,
|
||||
.lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
|
||||
.provoking_vertex_last = provoking_vertex_last,
|
||||
.can_cull = can_cull,
|
||||
};
|
||||
|
||||
unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
|
||||
unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
|
||||
shader->info.shared_size = total_lds_bytes;
|
||||
|
||||
nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
|
||||
state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
|
||||
state.const_out_prmcnt[0] != -1;
|
||||
if (!can_cull) {
|
||||
nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
|
||||
state.const_out_prmcnt, 4u);
|
||||
state.output_compile_time_known =
|
||||
state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
|
||||
state.const_out_prmcnt[0] != -1;
|
||||
}
|
||||
|
||||
if (!state.output_compile_time_known)
|
||||
state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
|
||||
|
|
|
|||
|
|
@ -1342,7 +1342,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
|
|||
assert(info->is_ngg);
|
||||
NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
|
||||
info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
|
||||
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last);
|
||||
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
|
||||
} else if (nir->info.stage == MESA_SHADER_MESH) {
|
||||
bool scratch_ring = false;
|
||||
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue