brw: fix local_invocation_index with quad derivaties on mesh/task shaders

For mesh/task shaders, the thread payload provides a local invocation
index, but it's always linear so it doesn't give the correct value when
quad derivatives are in use.
The lowering pass where all of this is done correctly for compute
shaders assumes load_local_invocation_index will be lowered in the
backend for mesh/task, calculates the values for the quads correctly but
then avoid replacing the original intrinsic and we remain with the wrong
results.

Add an intel specific intrinsic and always lower the generic one to that
(or whatever else was calculated) to avoid ambiguities and fix the value
for quad derivatives.

Fixes future CTS tests using mesh/task shaders under:
dEQP-VK.spirv_assembly.instruction.compute.compute_shader_derivatives.*

Fixes: d89bfb1ff7 ("intel/brw: Reorganize lowering of LocalID/Index to handle Mesh/Task")
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39276>
This commit is contained in:
Iván Briano 2026-01-12 13:18:41 -08:00 committed by Marge Bot
parent eb990cd81e
commit 5b48805b42
6 changed files with 14 additions and 16 deletions

View file

@ -1021,6 +1021,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_atest_pan:
case nir_intrinsic_zs_emit_pan:
case nir_intrinsic_load_return_param_amd:
case nir_intrinsic_load_local_invocation_index_intel:
is_divergent = true;
break;

View file

@ -2633,6 +2633,9 @@ system_value("fs_msaa_intel", 1)
# Per primitive remapping table offset.
system_value("per_primitive_remap_intel", 1)
# The (linear) local invocation index provided in the payload of mesh/task shaders.
system_value("local_invocation_index_intel", 1)
# Intrinsics for Intel bindless thread dispatch
# BASE=brw_topoloy_id
system_value("topology_id_intel", 1, indices=[BASE])

View file

@ -297,6 +297,9 @@ brw_compile_task(const struct brw_compiler *compiler,
NIR_PASS(_, nir, brw_nir_lower_launch_mesh_workgroups);
NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo,
NULL);
brw_prog_data_init(&prog_data->base.base, &params->base);
prog_data->base.local_size[0] = nir->info.workgroup_size[0];
@ -1015,6 +1018,9 @@ brw_compile_mesh(const struct brw_compiler *compiler,
if (prog_data->map.has_per_primitive_header)
NIR_PASS(_, nir, brw_nir_initialize_mue, &prog_data->map);
NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo,
NULL);
prog_data->autostrip_enable = brw_mesh_autostrip_enable(compiler, nir, &prog_data->map);
prog_data->base.uses_inline_data = brw_nir_uses_inline_data(nir) ||

View file

@ -4857,7 +4857,7 @@ brw_from_nir_emit_task_mesh_intrinsic(nir_to_brw_state &ntb,
UNREACHABLE("local invocation id should have been lowered earlier");
break;
case nir_intrinsic_load_local_invocation_index:
case nir_intrinsic_load_local_invocation_index_intel:
dest = retype(dest, BRW_TYPE_UD);
bld.MOV(dest, payload.local_index);
break;

View file

@ -15,7 +15,6 @@ struct lower_intrinsics_state {
/* Per-block cached values. */
bool computed;
nir_def *hw_index;
nir_def *local_index;
nir_def *local_id;
};
@ -24,7 +23,6 @@ static void
compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr *current)
{
assert(!state->computed);
state->hw_index = NULL;
state->local_index = NULL;
state->local_id = NULL;
state->computed = true;
@ -68,13 +66,8 @@ compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr
nir_def *linear;
if (nir->info.stage == MESA_SHADER_MESH || nir->info.stage == MESA_SHADER_TASK) {
/* Thread payload provides a linear index, keep track of it
* so it doesn't get removed.
*/
state->hw_index =
current->intrinsic == nir_intrinsic_load_local_invocation_index ?
&current->def : nir_load_local_invocation_index(b);
linear = state->hw_index;
/* Thread payload provides a linear index, just use that. */
linear = nir_load_local_invocation_index_intel(b);
} else {
nir_def *subgroup_id = nir_load_subgroup_id(b);
nir_def *thread_local_id =
@ -245,10 +238,6 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
if (!state->computed)
compute_local_index_id(state, intrinsic);
/* Will be lowered later by the backend. */
if (&intrinsic->def == state->hw_index)
continue;
assert(state->local_index);
sysval = state->local_index;
break;

View file

@ -1564,8 +1564,7 @@ anv_shader_lower_nir(struct anv_device *device,
}
}
if (mesa_shader_stage_is_compute(nir->info.stage) ||
mesa_shader_stage_is_mesh(nir->info.stage)) {
if (mesa_shader_stage_is_compute(nir->info.stage)) {
NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo,
&shader_data->prog_data.cs);
}