diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 52e2992a23c..5b9b95255a9 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -144,6 +144,7 @@ typedef struct
       ms_out_part vtx_attr;
       ms_out_part prm_attr;
       uint32_t indices_addr;
+      uint32_t cull_flags_addr;
       uint32_t total_size;
    } lds;
    /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS. */
@@ -174,6 +175,8 @@ typedef struct
 
    /* True if the lowering needs to insert the layer output. */
    bool insert_layer_output;
+   /* True if cull flags are used */
+   bool uses_cull_flags;
 
    struct {
       /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
@@ -2439,6 +2442,24 @@ ms_load_num_prims(nir_builder *b,
    return nir_load_shared(b, 1, 32, addr, .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
 }
 
+static void
+ms_store_cull_flag(nir_builder *b,
+                   nir_ssa_def *val,
+                   nir_ssa_def *offset_src,
+                   lower_ngg_ms_state *s)
+{
+   assert(val->num_components == 1);
+   assert(val->bit_size == 1);
+
+   if (!offset_src)
+      offset_src = nir_imm_int(b, 0);
+
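+   /* LDS can't hold 1-bit booleans, so widen the flag to an 8-bit integer;
+    * the output layout reserves one byte per primitive for these flags.
+    */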
+   nir_store_shared(b, nir_b2i8(b, val), offset_src, .base = s->layout.lds.cull_flags_addr);
+}
+
 static nir_ssa_def *
 lower_ms_store_output(nir_builder *b,
                       nir_intrinsic_instr *intrin,
@@ -2654,6 +2675,20 @@ ms_store_arrayed_output_intrin(nir_builder *b,
       nir_ssa_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
       ms_store_prim_indices(b, store_val, offset, s);
       return;
+   } else if (location == VARYING_SLOT_CULL_PRIMITIVE) {
+      /* EXT_mesh_shader cull primitive: per-primitive bool.
+       * To reduce LDS use, store these as an array of 8-bit values.
+       */
+      assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
+      assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
+      assert(nir_intrinsic_component(intrin) == 0);
+      assert(nir_intrinsic_write_mask(intrin) == 1);
+
+      nir_ssa_def *store_val = intrin->src[0].ssa;
+      nir_ssa_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
+      /* One byte per primitive, so the array index is the byte offset. */
+      ms_store_cull_flag(b, store_val, arr_index, s);
+      return;
    }
 
    ms_out_mode out_mode;
@@ -3054,6 +3089,16 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 
       /* Primitive connectivity data: describes which vertices the primitive uses. */
       nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
       nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
+      nir_ssa_def *cull_flag = NULL;
+
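+      /* The loaded flag feeds emit_pack_ngg_prim_exp_arg below; a culled
+       * primitive is exported as a null primitive, which the HW discards.
+       */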
+      if (s->uses_cull_flags) {
+         nir_ssa_def *loaded_cull_flag = nir_load_shared(b, 1, 8, invocation_index, .base = s->layout.lds.cull_flags_addr);
+         cull_flag = nir_i2b1(b, nir_u2u32(b, loaded_cull_flag));
+      }
+
       nir_ssa_def *indices[3];
       nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
@@ -3062,7 +3107,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
          indices[i] = nir_umin(b, indices[i], max_vtx_idx);
       }
 
-      nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, NULL, false);
+      nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, false);
       nir_export_primitive_amd(b, prim_exp_arg);
    }
    nir_pop_if(b, if_has_output_primitive);
@@ -3241,7 +3286,8 @@ ms_calculate_output_layout(unsigned api_shared_size,
                            uint64_t cross_invocation_output_access,
                            unsigned max_vertices,
                            unsigned max_primitives,
-                           unsigned vertices_per_prim)
+                           unsigned vertices_per_prim,
+                           bool uses_cull)
 {
    uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
    uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
@@ -3290,6 +3336,12 @@ ms_calculate_output_layout(unsigned api_shared_size,
    l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
    l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
 
+   if (uses_cull) {
+      /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
+      l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
+      l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
+   }
+
    /* NGG is only allowed to address up to 32K of LDS. */
    assert(l.lds.total_size <= 32 * 1024);
    return l;
@@ -3305,12 +3357,15 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
       num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
 
    uint64_t special_outputs =
-      BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
+      BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) |
+      BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
    uint64_t per_vertex_outputs =
       shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs;
    uint64_t per_primitive_outputs =
       shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
+   /* Whether the shader uses CullPrimitiveEXT */
+   bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
 
    /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */
    uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
                                       shader->info.outputs_accessed_indirectly;
@@ -3320,7 +3375,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
 
    ms_out_mem_layout layout =
       ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
-                                 cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
+                                 cross_invocation_access, max_vertices, max_primitives, vertices_per_prim, uses_cull);
 
    shader->info.shared_size = layout.lds.total_size;
    *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
@@ -3348,6 +3403,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
       .api_workgroup_size = api_workgroup_size,
       .hw_workgroup_size = hw_workgroup_size,
       .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
+      .uses_cull_flags = uses_cull,
    };
 
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
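
For reference, here is a minimal standalone C sketch of the primitive export
argument that the cull flag ends up in. It mirrors how emit_pack_ngg_prim_exp_arg
in this file packs the argument (10-bit slots per vertex index, null-primitive
flag in bit 31); the function name pack_prim_exp_arg and the 9-bit index range
check are illustrative assumptions, not part of the patch.

   #include <assert.h>
   #include <stdbool.h>
   #include <stdint.h>

   /* Scalar model of the NGG primitive export argument. Each vertex index
    * occupies a 10-bit slot (the high bit of each slot holds the edge flag,
    * unused by mesh shaders). Bit 31 marks a null primitive, which the HW
    * discards; this is how CullPrimitiveEXT takes effect.
    */
   static uint32_t pack_prim_exp_arg(unsigned num_vertices, const uint32_t indices[3], bool cull)
   {
      uint32_t arg = 0;

      assert(num_vertices >= 1 && num_vertices <= 3);
      for (unsigned i = 0; i < num_vertices; ++i) {
         assert(indices[i] < (1u << 9)); /* assumed: indices fit in 9 bits */
         arg |= indices[i] << (10 * i);
      }

      if (cull)
         arg |= 1u << 31; /* null primitive flag */

      return arg;
   }

With the layout above, the extra LDS cost of the feature is max_primitives
bytes (e.g. 256 bytes when max_primitives = 256), plus up to 15 bytes of
padding from the ALIGN(..., 16) of the cull flags region.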