ac/nir/ngg: Add EXT_mesh_shader CullPrimitiveEXT output.

This is a per-primitive boolean output.
When set to 1, the primitive should be culled.

Implement this by using this boolean as the null primitive
flag for primitive exports.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18367>
This commit is contained in:
Timur Kristóf 2022-02-28 20:12:00 +01:00 committed by Marge Bot
parent 1f8f4570f0
commit 448d09d44a

View file

@ -144,6 +144,7 @@ typedef struct
ms_out_part vtx_attr;
ms_out_part prm_attr;
uint32_t indices_addr;
uint32_t cull_flags_addr;
uint32_t total_size;
} lds;
/* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. */
@ -174,6 +175,8 @@ typedef struct
/* True if the lowering needs to insert the layer output. */
bool insert_layer_output;
/* True if cull flags are used */
bool uses_cull_flags;
struct {
/* Bitmask of components used: 4 bits per slot, 1 bit per component. */
@ -2439,6 +2442,21 @@ ms_load_num_prims(nir_builder *b,
return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
}
static void
ms_store_cull_flag(nir_builder *b,
nir_ssa_def *val,
nir_ssa_def *offset_src,
lower_ngg_ms_state *s)
{
assert(val->num_components == 1);
assert(val->bit_size == 1);
if (!offset_src)
offset_src = nir_imm_int(b, 0);
nir_store_shared(b, nir_b2i8(b, val), offset_src, .base = s->layout.lds.cull_flags_addr);
}
static nir_ssa_def *
lower_ms_store_output(nir_builder *b,
nir_intrinsic_instr *intrin,
@ -2654,6 +2672,20 @@ ms_store_arrayed_output_intrin(nir_builder *b,
nir_ssa_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
ms_store_prim_indices(b, store_val, offset, s);
return;
} else if (location == VARYING_SLOT_CULL_PRIMITIVE) {
/* EXT_mesh_shader cull primitive: per-primitive bool.
* To reduce LDS use, store these as an array of 8-bit values.
*/
assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
assert(nir_intrinsic_component(intrin) == 0);
assert(nir_intrinsic_write_mask(intrin) == 1);
nir_ssa_def *store_val = intrin->src[0].ssa;
nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
nir_ssa_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
ms_store_cull_flag(b, store_val, offset, s);
return;
}
ms_out_mode out_mode;
@ -3054,6 +3086,13 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
/* Primitive connectivity data: describes which vertices the primitive uses. */
nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
nir_ssa_def *cull_flag = NULL;
if (s->uses_cull_flags) {
nir_ssa_def *loaded_cull_flag = nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr);
cull_flag = nir_i2b1(b, nir_u2u32(b, loaded_cull_flag));
}
nir_ssa_def *indices[3];
nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
@ -3062,7 +3101,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
indices[i] = nir_umin(b, indices[i], max_vtx_idx);
}
nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, NULL, false);
nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, false);
nir_export_primitive_amd(b, prim_exp_arg);
}
nir_pop_if(b, if_has_output_primitive);
@ -3241,7 +3280,8 @@ ms_calculate_output_layout(unsigned api_shared_size,
uint64_t cross_invocation_output_access,
unsigned max_vertices,
unsigned max_primitives,
unsigned vertices_per_prim)
unsigned vertices_per_prim,
bool uses_cull)
{
uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
@ -3290,6 +3330,12 @@ ms_calculate_output_layout(unsigned api_shared_size,
l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
if (uses_cull) {
/* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
}
/* NGG is only allowed to address up to 32K of LDS. */
assert(l.lds.total_size <= 32 * 1024);
return l;
@ -3305,12 +3351,15 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
uint64_t special_outputs =
BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) |
BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
uint64_t per_vertex_outputs =
shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs;
uint64_t per_primitive_outputs =
shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
/* Whether the shader uses CullPrimitiveEXT */
bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
/* Can't handle indirect register addressing, pretend as if they were cross-invocation. */
uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
shader->info.outputs_accessed_indirectly;
@ -3320,7 +3369,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
ms_out_mem_layout layout =
ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
cross_invocation_access, max_vertices, max_primitives, vertices_per_prim, uses_cull);
shader->info.shared_size = layout.lds.total_size;
*out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
@ -3348,6 +3397,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
.api_workgroup_size = api_workgroup_size,
.hw_workgroup_size = hw_workgroup_size,
.insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
.uses_cull_flags = uses_cull,
};
nir_function_impl *impl = nir_shader_get_entrypoint(shader);