diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 1965ef96e6c..8959741ac75 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -28,6 +28,11 @@
 #include "u_math.h"
 #include "u_vector.h"
 
+#define SPECIAL_MS_OUT_MASK \
+   (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
+    BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
+    BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
+
 enum {
    nggc_passflag_used_by_pos = 1,
    nggc_passflag_used_by_other = 2,
@@ -2858,6 +2863,12 @@ ms_store_prim_indices(nir_builder *b,
 {
    assert(val->num_components <= 3);
 
+   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
+      for (unsigned c = 0; c < s->vertices_per_prim; ++c)
+         nir_store_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c], nir_channel(b, val, c), 0x1);
+      return;
+   }
+
    if (!offset_src)
       offset_src = nir_imm_int(b, 0);
 
@@ -2901,6 +2912,11 @@ ms_store_cull_flag(nir_builder *b,
    assert(val->num_components == 1);
    assert(val->bit_size == 1);
 
+   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
+      nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, val), 0x1);
+      return;
+   }
+
    if (!offset_src)
       offset_src = nir_imm_int(b, 0);
 
@@ -3626,7 +3642,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
    nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
    {
       /* Generic per-primitive attributes. */
-      ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s);
+      ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs & ~SPECIAL_MS_OUT_MASK, s);
 
       /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
       if (s->insert_layer_output) {
@@ -3639,19 +3655,35 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 
       /* Primitive connectivity data: describes which vertices the primitive uses. */
       nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
-      nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
+      nir_ssa_def *indices_loaded = NULL;
       nir_ssa_def *cull_flag = NULL;
 
+      if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
+         nir_ssa_def *indices[3] = {0};
+         for (unsigned c = 0; c < s->vertices_per_prim; ++c)
+            indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
+         indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
+      }
+      else {
+         indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
+         indices_loaded = nir_u2u32(b, indices_loaded);
+      }
+
       if (s->uses_cull_flags) {
-         nir_ssa_def *loaded_cull_flag = nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr);
-         cull_flag = nir_i2b1(b, nir_u2u32(b, loaded_cull_flag));
+         nir_ssa_def *loaded_cull_flag = NULL;
+         if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
+            loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
+         else
+            loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
+
+         cull_flag = nir_i2b1(b, loaded_cull_flag);
       }
 
       nir_ssa_def *indices[3];
      nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
 
       for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
-         indices[i] = nir_u2u32(b, nir_channel(b, indices_loaded, i));
+         indices[i] = nir_channel(b, indices_loaded, i);
          indices[i] = nir_umin(b, indices[i], max_vtx_idx);
       }
 
@@ -3834,11 +3866,16 @@ ms_calculate_output_layout(unsigned api_shared_size,
                            uint64_t cross_invocation_output_access,
                            unsigned max_vertices,
                            unsigned max_primitives,
-                           unsigned vertices_per_prim,
-                           bool uses_cull)
+                           unsigned vertices_per_prim)
 {
-   uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
-   uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
+   const uint64_t lds_per_vertex_output_mask =
+      per_vertex_output_mask & cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
+   const uint64_t lds_per_primitive_output_mask =
+      per_primitive_output_mask & cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
+   const bool cross_invocation_indices =
+      cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
+   const bool cross_invocation_cull_primitive =
+      cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
 
    /* Shared memory used by the API shader. */
    ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
@@ -3869,7 +3906,9 @@
     * Move the outputs that do not fit LDS, to VRAM.
     * Start with per-primitive attributes, because those are grouped at the end.
     */
-   while (l.lds.total_size >= 30 * 1024) {
+   const unsigned usable_lds_kbytes =
+      (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
+   while (l.lds.total_size >= usable_lds_kbytes * 1024) {
       if (l.lds.prm_attr.mask)
          ms_move_output(&l.lds.prm_attr, &l.vram.prm_attr);
       else if (l.lds.vtx_attr.mask)
@@ -3880,11 +3919,13 @@
       ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
    }
 
-   /* Indices: flat array of 8-bit vertex indices for each primitive. */
-   l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
-   l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
+   if (cross_invocation_indices) {
+      /* Indices: flat array of 8-bit vertex indices for each primitive. */
+      l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
+      l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
+   }
 
-   if (uses_cull) {
+   if (cross_invocation_cull_primitive) {
       /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
       l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
       l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
@@ -3904,13 +3945,10 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    unsigned vertices_per_prim =
       num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
 
-   uint64_t special_outputs =
-      BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) |
-      BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
    uint64_t per_vertex_outputs =
-      shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs;
+      shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
    uint64_t per_primitive_outputs =
-      shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
+      shader->info.per_primitive_outputs & shader->info.outputs_written;
 
    /* Whether the shader uses CullPrimitiveEXT */
    bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
@@ -3918,12 +3956,17 @@
    uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
                                       shader->info.outputs_accessed_indirectly;
 
+   if (shader->info.mesh.nv) {
+      per_primitive_outputs &= ~SPECIAL_MS_OUT_MASK;
+      cross_invocation_access |= SPECIAL_MS_OUT_MASK;
+   }
+
    unsigned max_vertices = shader->info.mesh.max_vertices_out;
    unsigned max_primitives = shader->info.mesh.max_primitives_out;
 
    ms_out_mem_layout layout =
      ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
-                                cross_invocation_access, max_vertices, max_primitives, vertices_per_prim, uses_cull);
+                                cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
 
    shader->info.shared_size = layout.lds.total_size;
    *out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;
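
For context, a minimal standalone sketch of the LDS budget decision this patch introduces in ms_calculate_output_layout. The BITFIELD64_BIT() macro and the VARYING_SLOT_* values below are simplified stand-ins for the real Mesa definitions (the slot numbers here are made up for illustration); only the masking logic mirrors the patch: the special outputs never occupy generic attribute slots, and their LDS arrays are allocated only when they are actually accessed across invocations.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Stand-ins for the real Mesa helpers; slot values are hypothetical. */
#define BITFIELD64_BIT(b) (1ull << (b))
enum {
   VARYING_SLOT_PRIMITIVE_COUNT = 61,   /* hypothetical value */
   VARYING_SLOT_PRIMITIVE_INDICES = 62, /* hypothetical value */
   VARYING_SLOT_CULL_PRIMITIVE = 63,    /* hypothetical value */
};

#define SPECIAL_MS_OUT_MASK \
   (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
    BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
    BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))

/* Mirrors the patch: indices and cull flags only need LDS arrays when
 * some invocation reads a value that another invocation wrote. */
static unsigned usable_lds_kbytes(uint64_t cross_invocation_access)
{
   const bool lds_indices =
      cross_invocation_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
   const bool lds_cull =
      cross_invocation_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
   return (lds_indices || lds_cull) ? 30 : 31;
}

int main(void)
{
   /* No cross-invocation access: indices and cull flags stay in
    * per-invocation variables, so the full 31 KiB budget is usable. */
   printf("EXT-style path: %u KiB\n", usable_lds_kbytes(0));

   /* The patch marks all special outputs as cross-invocation when
    * shader->info.mesh.nv is set, keeping the LDS arrays and the
    * smaller budget for NV_mesh_shader. */
   printf("NV path:        %u KiB\n", usable_lds_kbytes(SPECIAL_MS_OUT_MASK));
   return 0;
}

In other words, EXT_mesh_shader-style shaders that write the primitive indices and cull flag only from the owning invocation skip the shared-memory arrays entirely, while the NV_mesh_shader path conservatively treats them as cross-invocation and keeps the previous 30 KiB cap.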