anv: handle mesh shaders with max primitives == 0

Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20279>
This commit is contained in:
Marcin Ślusarz 2022-12-12 14:28:05 +01:00 committed by Marge Bot
parent c26a053f2b
commit d7a1916798
2 changed files with 54 additions and 34 deletions

View file

@ -45,13 +45,37 @@
/* Needed for SWIZZLE macros */
#include "program/prog_instruction.h"
struct lower_mesh_ext_state {
struct lower_set_vtx_and_prim_count_state {
nir_variable *primitive_count;
nir_variable *primitive_indices;
};
static nir_variable *
anv_nir_prim_count_store(nir_builder *b, nir_ssa_def *val)
{
nir_variable *primitive_count =
nir_variable_create(b->shader,
nir_var_shader_out,
glsl_uint_type(),
"gl_PrimitiveCountNV");
primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT;
primitive_count->data.interpolation = INTERP_MODE_NONE;
nir_ssa_def *local_invocation_index = nir_build_load_local_invocation_index(b);
nir_ssa_def *cmp = nir_ieq(b, local_invocation_index,
nir_imm_int(b, 0));
nir_if *if_stmt = nir_push_if(b, cmp);
{
nir_deref_instr *prim_count_deref = nir_build_deref_var(b, primitive_count);
nir_store_deref(b, prim_count_deref, val, 1);
}
nir_pop_if(b, if_stmt);
return primitive_count;
}
static bool
anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data)
anv_nir_lower_set_vtx_and_prim_count_instr(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
@ -60,30 +84,13 @@ anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data)
if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count)
return false;
struct lower_mesh_ext_state *state = data;
struct lower_set_vtx_and_prim_count_state *state = data;
/* this intrinsic should show up only once */
assert(state->primitive_count == NULL);
state->primitive_count =
nir_variable_create(b->shader,
nir_var_shader_out,
glsl_uint_type(),
"gl_PrimitiveCountNV");
state->primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT;
state->primitive_count->data.interpolation = INTERP_MODE_NONE;
b->cursor = nir_before_instr(&intrin->instr);
nir_ssa_def *local_invocation_index = nir_build_load_local_invocation_index(b);
nir_ssa_def *cmp = nir_ieq(b, local_invocation_index,
nir_imm_int(b, 0));
nir_if *if_stmt = nir_push_if(b, cmp);
{
nir_deref_instr *prim_count_deref = nir_build_deref_var(b, state->primitive_count);
nir_store_deref(b, prim_count_deref, intrin->src[1].ssa, 1);
}
nir_pop_if(b, if_stmt);
state->primitive_count = anv_nir_prim_count_store(b, intrin->src[1].ssa);
nir_instr_remove(instr);
@ -91,13 +98,29 @@ anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data)
}
static bool
anv_nir_lower_mesh_ext(nir_shader *nir)
anv_nir_lower_set_vtx_and_prim_count(nir_shader *nir)
{
struct lower_mesh_ext_state state = { NULL, };
struct lower_set_vtx_and_prim_count_state state = { NULL, };
return nir_shader_instructions_pass(nir, anv_nir_lower_mesh_ext_instr,
nir_metadata_none,
&state);
nir_shader_instructions_pass(nir,
anv_nir_lower_set_vtx_and_prim_count_instr,
nir_metadata_none,
&state);
/* If we didn't find set_vertex_and_primitive_count, then we have to
* insert store of value 0 to primitive_count.
*/
if (state.primitive_count == NULL) {
nir_builder b;
nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
nir_builder_init(&b, entrypoint);
b.cursor = nir_before_block(nir_start_block(entrypoint));
nir_ssa_def *zero = nir_imm_int(&b, 0);
state.primitive_count = anv_nir_prim_count_store(&b, zero);
}
assert(state.primitive_count != NULL);
return true;
}
/* Eventually, this will become part of anv_CreateShader. Unfortunately,
@ -231,12 +254,9 @@ anv_shader_stage_to_nir(struct anv_device *device,
brw_preprocess_nir(compiler, nir, &opts);
if (nir->info.stage == MESA_SHADER_MESH && !nir->info.mesh.nv) {
bool progress = false;
NIR_PASS(progress, nir, anv_nir_lower_mesh_ext);
if (progress) {
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
}
NIR_PASS(_, nir, anv_nir_lower_set_vtx_and_prim_count);
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
}
return nir;

View file

@ -1806,7 +1806,7 @@ emit_mesh_state(struct anv_graphics_pipeline *pipeline)
mesh.LocalXMaximum = mesh_dispatch.group_size - 1;
mesh.EmitLocalIDX = true;
mesh.MaximumPrimitiveCount = mesh_prog_data->map.max_primitives - 1;
mesh.MaximumPrimitiveCount = MAX2(mesh_prog_data->map.max_primitives, 1) - 1;
mesh.OutputTopology = output_topology;
mesh.PerVertexDataPitch = mesh_prog_data->map.per_vertex_pitch_dw / 8;
mesh.PerPrimitiveDataPresent = mesh_prog_data->map.per_primitive_pitch_dw > 0;