ac/nir/ngg: Don't use LDS for same-invocation indices and cull outputs.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18566>
This commit is contained in:
Timur Kristóf 2022-09-10 01:26:10 +02:00 committed by Marge Bot
parent bb4bdba17e
commit 697ea02202

View file

@@ -28,6 +28,11 @@
#include "u_math.h"
#include "u_vector.h"
#define SPECIAL_MS_OUT_MASK \
(BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
enum {
nggc_passflag_used_by_pos = 1,
nggc_passflag_used_by_other = 2,
@@ -2858,6 +2863,12 @@ ms_store_prim_indices(nir_builder *b,
{
assert(val->num_components <= 3);
if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
for (unsigned c = 0; c < s->vertices_per_prim; ++c)
nir_store_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c], nir_channel(b, val, c), 0x1);
return;
}
if (!offset_src)
offset_src = nir_imm_int(b, 0);
@@ -2901,6 +2912,11 @@ ms_store_cull_flag(nir_builder *b,
assert(val->num_components == 1);
assert(val->bit_size == 1);
if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, val), 0x1);
return;
}
if (!offset_src)
offset_src = nir_imm_int(b, 0);
@@ -3626,7 +3642,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
{
/* Generic per-primitive attributes. */
ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs, s);
ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs & ~SPECIAL_MS_OUT_MASK, s);
/* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
if (s->insert_layer_output) {
@@ -3639,19 +3655,35 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
/* Primitive connectivity data: describes which vertices the primitive uses. */
nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
nir_ssa_def *indices_loaded = NULL;
nir_ssa_def *cull_flag = NULL;
if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
nir_ssa_def *indices[3] = {0};
for (unsigned c = 0; c < s->vertices_per_prim; ++c)
indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
}
else {
indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
indices_loaded = nir_u2u32(b, indices_loaded);
}
if (s->uses_cull_flags) {
nir_ssa_def *loaded_cull_flag = nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr);
cull_flag = nir_i2b1(b, nir_u2u32(b, loaded_cull_flag));
nir_ssa_def *loaded_cull_flag = NULL;
if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
else
loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
cull_flag = nir_i2b1(b, loaded_cull_flag);
}
nir_ssa_def *indices[3];
nir_ssa_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
indices[i] = nir_u2u32(b, nir_channel(b, indices_loaded, i));
indices[i] = nir_channel(b, indices_loaded, i);
indices[i] = nir_umin(b, indices[i], max_vtx_idx);
}
@@ -3834,11 +3866,16 @@ ms_calculate_output_layout(unsigned api_shared_size,
uint64_t cross_invocation_output_access,
unsigned max_vertices,
unsigned max_primitives,
unsigned vertices_per_prim,
bool uses_cull)
unsigned vertices_per_prim)
{
uint64_t lds_per_vertex_output_mask = per_vertex_output_mask & cross_invocation_output_access;
uint64_t lds_per_primitive_output_mask = per_primitive_output_mask & cross_invocation_output_access;
const uint64_t lds_per_vertex_output_mask =
per_vertex_output_mask & cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
const uint64_t lds_per_primitive_output_mask =
per_primitive_output_mask & cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
const bool cross_invocation_indices =
cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
const bool cross_invocation_cull_primitive =
cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
/* Shared memory used by the API shader. */
ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
@@ -3869,7 +3906,9 @@ ms_calculate_output_layout(unsigned api_shared_size,
* Move the outputs that do not fit LDS, to VRAM.
* Start with per-primitive attributes, because those are grouped at the end.
*/
while (l.lds.total_size >= 30 * 1024) {
const unsigned usable_lds_kbytes =
(cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
while (l.lds.total_size >= usable_lds_kbytes * 1024) {
if (l.lds.prm_attr.mask)
ms_move_output(&l.lds.prm_attr, &l.vram.prm_attr);
else if (l.lds.vtx_attr.mask)
@@ -3880,11 +3919,13 @@ ms_calculate_output_layout(unsigned api_shared_size,
ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
}
/* Indices: flat array of 8-bit vertex indices for each primitive. */
l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
if (cross_invocation_indices) {
/* Indices: flat array of 8-bit vertex indices for each primitive. */
l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
}
if (uses_cull) {
if (cross_invocation_cull_primitive) {
/* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
@@ -3904,13 +3945,10 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
unsigned vertices_per_prim =
num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
uint64_t special_outputs =
BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) |
BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
uint64_t per_vertex_outputs =
shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~special_outputs;
shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
uint64_t per_primitive_outputs =
shader->info.per_primitive_outputs & shader->info.outputs_written & ~special_outputs;
shader->info.per_primitive_outputs & shader->info.outputs_written;
/* Whether the shader uses CullPrimitiveEXT */
bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
@@ -3918,12 +3956,17 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
shader->info.outputs_accessed_indirectly;
if (shader->info.mesh.nv) {
per_primitive_outputs &= ~SPECIAL_MS_OUT_MASK;
cross_invocation_access |= SPECIAL_MS_OUT_MASK;
}
unsigned max_vertices = shader->info.mesh.max_vertices_out;
unsigned max_primitives = shader->info.mesh.max_primitives_out;
ms_out_mem_layout layout =
ms_calculate_output_layout(shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
cross_invocation_access, max_vertices, max_primitives, vertices_per_prim, uses_cull);
cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
shader->info.shared_size = layout.lds.total_size;
*out_needs_scratch_ring = layout.vram.vtx_attr.mask || layout.vram.prm_attr.mask;