brw: fix local_invocation_index with quad derivaties on mesh/task shaders

For mesh/task shaders, the thread payload provides a local invocation
index, but it's always linear so it doesn't give the correct value when
quad derivatives are in use.
The lowering pass where all of this is done correctly for compute
shaders assumes load_local_invocation_index will be lowered in the
backend for mesh/task, calculates the values for the quads correctly but
then avoid replacing the original intrinsic and we remain with the wrong
results.

Add an intel specific intrinsic and always lower the generic one to that
(or whatever else was calculated) to avoid ambiguities and fix the value
for quad derivatives.

Fixes future CTS tests using mesh/task shaders under:
dEQP-VK.spirv_assembly.instruction.compute.compute_shader_derivatives.*

Fixes: d89bfb1ff7 ("intel/brw: Reorganize lowering of LocalID/Index to handle Mesh/Task")
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
(cherry picked from commit 5b48805b42)

Conflicts:
	src/compiler/nir/nir_divergence_analysis.c
	src/intel/compiler/brw/brw_compile_mesh.cpp
	src/intel/vulkan/anv_shader_compile.c

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39745>
This commit is contained in:
Iván Briano 2026-01-12 13:18:41 -08:00 committed by Dylan Baker
parent 976bd36982
commit 3fc2b77823
7 changed files with 16 additions and 17 deletions

View file

@ -354,7 +354,7 @@
"description": "brw: fix local_invocation_index with quad derivaties on mesh/task shaders",
"nominated": true,
"nomination_type": 2,
"resolution": 0,
"resolution": 1,
"main_sha": null,
"because_sha": "d89bfb1ff750445f77717ea44884decf93adad97",
"notes": null

View file

@ -989,6 +989,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_cmat_load_shared_nv:
case nir_intrinsic_load_converted_output_pan:
case nir_intrinsic_load_readonly_output_pan:
case nir_intrinsic_load_local_invocation_index_intel:
is_divergent = true;
break;

View file

@ -2483,6 +2483,9 @@ system_value("fs_msaa_intel", 1)
# Per primitive remapping table offset.
system_value("per_primitive_remap_intel", 1)
# The (linear) local invocation index provided in the payload of mesh/task shaders.
system_value("local_invocation_index_intel", 1)
# Intrinsics for Intel bindless thread dispatch
# BASE=brw_topoloy_id
system_value("topology_id_intel", 1, indices=[BASE])

View file

@ -372,6 +372,9 @@ brw_compile_task(const struct brw_compiler *compiler,
NIR_PASS(_, nir, brw_nir_lower_launch_mesh_workgroups);
NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo,
NULL);
brw_prog_data_init(&prog_data->base.base, &params->base);
prog_data->base.local_size[0] = nir->info.workgroup_size[0];
@ -1231,6 +1234,10 @@ brw_compile_mesh(const struct brw_compiler *compiler,
apply_wa_18019110168 ? wa_18019110168_mapping : NULL);
brw_nir_lower_mue_outputs(nir, &prog_data->map);
NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo,
NULL);
prog_data->autostrip_enable = brw_mesh_autostrip_enable(compiler, nir, &prog_data->map);
prog_data->base.uses_inline_data = brw_nir_uses_inline_data(nir) ||

View file

@ -5753,7 +5753,7 @@ brw_from_nir_emit_task_mesh_intrinsic(nir_to_brw_state &ntb, const brw_builder &
UNREACHABLE("local invocation id should have been lowered earlier");
break;
case nir_intrinsic_load_local_invocation_index:
case nir_intrinsic_load_local_invocation_index_intel:
dest = retype(dest, BRW_TYPE_UD);
bld.MOV(dest, payload.local_index);
break;

View file

@ -33,7 +33,6 @@ struct lower_intrinsics_state {
/* Per-block cached values. */
bool computed;
nir_def *hw_index;
nir_def *local_index;
nir_def *local_id;
};
@ -42,7 +41,6 @@ static void
compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr *current)
{
assert(!state->computed);
state->hw_index = NULL;
state->local_index = NULL;
state->local_id = NULL;
state->computed = true;
@ -86,13 +84,8 @@ compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr
nir_def *linear;
if (nir->info.stage == MESA_SHADER_MESH || nir->info.stage == MESA_SHADER_TASK) {
/* Thread payload provides a linear index, keep track of it
* so it doesn't get removed.
*/
state->hw_index =
current->intrinsic == nir_intrinsic_load_local_invocation_index ?
&current->def : nir_load_local_invocation_index(b);
linear = state->hw_index;
/* Thread payload provides a linear index, just use that. */
linear = nir_load_local_invocation_index_intel(b);
} else {
nir_def *subgroup_id = nir_load_subgroup_id(b);
nir_def *thread_local_id =
@ -263,10 +256,6 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
if (!state->computed)
compute_local_index_id(state, intrinsic);
/* Will be lowered later by the backend. */
if (&intrinsic->def == state->hw_index)
continue;
assert(state->local_index);
sysval = state->local_index;
break;

View file

@ -1571,8 +1571,7 @@ anv_shader_lower_nir(struct anv_device *device,
}
}
if (mesa_shader_stage_is_compute(nir->info.stage) ||
mesa_shader_stage_is_mesh(nir->info.stage)) {
if (mesa_shader_stage_is_compute(nir->info.stage)) {
NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo,
&shader_data->prog_data.cs);
}