anv/brw: rework primitive count writing

Instead the complicated logic we currently have, do this :

We start with this shader :

int main() {
   ...
   if (...) {
      SetMeshOutputsEXT(0, 0);
      return;
   } else {
      SetMeshOutputsEXT(...);
   }
   ...
}

We turn it into this :

int main() {
   uint __temp_prim_count = 0;
   ...
   if (...) {
      __temp_prim_count = 0;
      return;
   } else {
      __temp_prim_count = ...;
   }
   ...

   if (is_first_group_lane()) {
      SetMeshOutputsEXT(..., __temp_prim_count);
   }
}

This works because the SPIRV spec says this :

   "The arguments are taken from the first invocation in each
    workgroup. Any invocation must execute this instruction no more
    than once and under uniform control flow."

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Cc: mesa-stable
Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/12388
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33038>
This commit is contained in:
Lionel Landwerlin 2025-01-15 15:45:07 +02:00 committed by Marge Bot
parent 4cc847cfd4
commit 7ddb49653d
2 changed files with 72 additions and 87 deletions

View file

@ -265,6 +265,74 @@ brw_nir_align_launch_mesh_workgroups(nir_shader *nir)
NULL);
}
static bool
lower_set_vtx_and_prim_to_temp_write(nir_builder *b,
nir_intrinsic_instr *intrin,
void *data)
{
if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count)
return false;
/* Detect some cases of invalid primitive count. They might lead to URB
* memory corruption, where workgroups overwrite each other output memory.
*/
if (nir_src_is_const(intrin->src[1]) &&
nir_src_as_uint(intrin->src[1]) > b->shader->info.mesh.max_primitives_out)
unreachable("number of primitives bigger than max specified");
b->cursor = nir_instr_remove(&intrin->instr);
nir_variable *temporary_primitive_count = (nir_variable *)data;
nir_store_var(b, temporary_primitive_count, intrin->src[1].ssa, 0x1);
return true;
}
static bool
brw_nir_lower_mesh_primitive_count(nir_shader *nir)
{
nir_function_impl *impl = nir_shader_get_entrypoint(nir);
nir_variable *temporary_primitive_count =
nir_local_variable_create(impl,
glsl_uint_type(),
"__temp_primitive_count");
nir_shader_intrinsics_pass(nir,
lower_set_vtx_and_prim_to_temp_write,
nir_metadata_control_flow,
temporary_primitive_count);
nir_builder _b = nir_builder_at(nir_before_impl(impl)), *b = &_b;
nir_store_var(b, temporary_primitive_count, nir_imm_int(b, 0), 0x1);
b->cursor = nir_after_impl(impl);
/* Have a single lane write the primitive count */
nir_def *local_invocation_index = nir_load_local_invocation_index(b);
nir_push_if(b, nir_ieq_imm(b, local_invocation_index, 0));
{
nir_variable *final_primitive_count =
nir_create_variable_with_location(nir, nir_var_shader_out,
VARYING_SLOT_PRIMITIVE_COUNT,
glsl_uint_type());
final_primitive_count->name = ralloc_strdup(final_primitive_count,
"gl_PrimitiveCountNV");
final_primitive_count->data.interpolation = INTERP_MODE_NONE;
nir_store_var(b, final_primitive_count,
nir_load_var(b, temporary_primitive_count), 0x1);
}
nir_pop_if(b, NULL);
nir_metadata_preserve(impl, nir_metadata_none);
nir->info.outputs_written |= VARYING_BIT_PRIMITIVE_COUNT;
return true;
}
static void
brw_emit_urb_fence(fs_visitor &s)
{
@ -1651,6 +1719,10 @@ brw_compile_mesh(const struct brw_compiler *compiler,
prog_data->uses_drawid =
BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_DRAW_ID);
NIR_PASS(_, nir, brw_nir_lower_mesh_primitive_count);
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
brw_nir_lower_tue_inputs(nir, params->tue_map);
brw_compute_mue_map(compiler, nir, &prog_data->map,

View file

@ -46,87 +46,6 @@
#include "vk_render_pass.h"
#include "vk_util.h"
struct lower_set_vtx_and_prim_count_state {
nir_variable *primitive_count;
};
static nir_variable *
anv_nir_prim_count_store(nir_builder *b, nir_def *val)
{
nir_variable *primitive_count =
nir_variable_create(b->shader,
nir_var_shader_out,
glsl_uint_type(),
"gl_PrimitiveCountNV");
primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT;
primitive_count->data.interpolation = INTERP_MODE_NONE;
nir_def *local_invocation_index = nir_load_local_invocation_index(b);
nir_def *cmp = nir_ieq_imm(b, local_invocation_index, 0);
nir_if *if_stmt = nir_push_if(b, cmp);
{
nir_deref_instr *prim_count_deref = nir_build_deref_var(b, primitive_count);
nir_store_deref(b, prim_count_deref, val, 1);
}
nir_pop_if(b, if_stmt);
return primitive_count;
}
static bool
anv_nir_lower_set_vtx_and_prim_count_instr(nir_builder *b,
nir_intrinsic_instr *intrin,
void *data)
{
if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count)
return false;
/* Detect some cases of invalid primitive count. They might lead to URB
* memory corruption, where workgroups overwrite each other output memory.
*/
if (nir_src_is_const(intrin->src[1]) &&
nir_src_as_uint(intrin->src[1]) > b->shader->info.mesh.max_primitives_out) {
assert(!"number of primitives bigger than max specified");
}
struct lower_set_vtx_and_prim_count_state *state = data;
/* this intrinsic should show up only once */
assert(state->primitive_count == NULL);
b->cursor = nir_before_instr(&intrin->instr);
state->primitive_count = anv_nir_prim_count_store(b, intrin->src[1].ssa);
nir_instr_remove(&intrin->instr);
return true;
}
static bool
anv_nir_lower_set_vtx_and_prim_count(nir_shader *nir)
{
struct lower_set_vtx_and_prim_count_state state = { NULL, };
nir_shader_intrinsics_pass(nir, anv_nir_lower_set_vtx_and_prim_count_instr,
nir_metadata_none,
&state);
/* If we didn't find set_vertex_and_primitive_count, then we have to
* insert store of value 0 to primitive_count.
*/
if (state.primitive_count == NULL) {
nir_builder b;
nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
b = nir_builder_at(nir_before_impl(entrypoint));
nir_def *zero = nir_imm_int(&b, 0);
state.primitive_count = anv_nir_prim_count_store(&b, zero);
}
assert(state.primitive_count != NULL);
return true;
}
/* Eventually, this will become part of anv_CreateShader. Unfortunately,
* we can't do that yet because we don't have the ability to copy nir.
*/
@ -2088,12 +2007,6 @@ anv_pipeline_nir_preprocess(struct anv_pipeline *pipeline,
};
brw_preprocess_nir(compiler, stage->nir, &opts);
if (stage->nir->info.stage == MESA_SHADER_MESH) {
NIR_PASS(_, stage->nir, anv_nir_lower_set_vtx_and_prim_count);
NIR_PASS(_, stage->nir, nir_opt_dce);
NIR_PASS(_, stage->nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
}
NIR_PASS(_, stage->nir, nir_opt_barrier_modes);
nir_shader_gather_info(stage->nir, nir_shader_get_entrypoint(stage->nir));