From 5b48805b42739dffbc52805cb4a3eddca7314a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Briano?= Date: Mon, 12 Jan 2026 13:18:41 -0800 Subject: [PATCH] brw: fix local_invocation_index with quad derivaties on mesh/task shaders For mesh/task shaders, the thread payload provides a local invocation index, but it's always linear so it doesn't give the correct value when quad derivatives are in use. The lowering pass where all of this is done correctly for compute shaders assumes load_local_invocation_index will be lowered in the backend for mesh/task, calculates the values for the quads correctly but then avoid replacing the original intrinsic and we remain with the wrong results. Add an intel specific intrinsic and always lower the generic one to that (or whatever else was calculated) to avoid ambiguities and fix the value for quad derivatives. Fixes future CTS tests using mesh/task shaders under: dEQP-VK.spirv_assembly.instruction.compute.compute_shader_derivatives.* Fixes: d89bfb1ff75 ("intel/brw: Reorganize lowering of LocalID/Index to handle Mesh/Task") Reviewed-by: Lionel Landwerlin Reviewed-by: Alyssa Rosenzweig Part-of: --- src/compiler/nir/nir_divergence_analysis.c | 1 + src/compiler/nir/nir_intrinsics.py | 3 +++ src/intel/compiler/brw/brw_compile_mesh.cpp | 6 ++++++ src/intel/compiler/brw/brw_from_nir.cpp | 2 +- .../compiler/brw/brw_nir_lower_cs_intrinsics.c | 15 ++------------- src/intel/vulkan/anv_shader_compile.c | 3 +-- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c index 60f7880fe0b..93a43961f14 100644 --- a/src/compiler/nir/nir_divergence_analysis.c +++ b/src/compiler/nir/nir_divergence_analysis.c @@ -1021,6 +1021,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state) case nir_intrinsic_atest_pan: case nir_intrinsic_zs_emit_pan: case nir_intrinsic_load_return_param_amd: + case nir_intrinsic_load_local_invocation_index_intel: is_divergent = true; break; diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index b70ef9b774a..7514ce05116 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -2633,6 +2633,9 @@ system_value("fs_msaa_intel", 1) # Per primitive remapping table offset. system_value("per_primitive_remap_intel", 1) +# The (linear) local invocation index provided in the payload of mesh/task shaders. +system_value("local_invocation_index_intel", 1) + # Intrinsics for Intel bindless thread dispatch # BASE=brw_topoloy_id system_value("topology_id_intel", 1, indices=[BASE]) diff --git a/src/intel/compiler/brw/brw_compile_mesh.cpp b/src/intel/compiler/brw/brw_compile_mesh.cpp index a91bb30e402..5379316c96a 100644 --- a/src/intel/compiler/brw/brw_compile_mesh.cpp +++ b/src/intel/compiler/brw/brw_compile_mesh.cpp @@ -297,6 +297,9 @@ brw_compile_task(const struct brw_compiler *compiler, NIR_PASS(_, nir, brw_nir_lower_launch_mesh_workgroups); + NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo, + NULL); + brw_prog_data_init(&prog_data->base.base, ¶ms->base); prog_data->base.local_size[0] = nir->info.workgroup_size[0]; @@ -1015,6 +1018,9 @@ brw_compile_mesh(const struct brw_compiler *compiler, if (prog_data->map.has_per_primitive_header) NIR_PASS(_, nir, brw_nir_initialize_mue, &prog_data->map); + NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo, + NULL); + prog_data->autostrip_enable = brw_mesh_autostrip_enable(compiler, nir, &prog_data->map); prog_data->base.uses_inline_data = brw_nir_uses_inline_data(nir) || diff --git a/src/intel/compiler/brw/brw_from_nir.cpp b/src/intel/compiler/brw/brw_from_nir.cpp index f59cdd5ee5b..fcda6305ed9 100644 --- a/src/intel/compiler/brw/brw_from_nir.cpp +++ b/src/intel/compiler/brw/brw_from_nir.cpp @@ -4857,7 +4857,7 @@ brw_from_nir_emit_task_mesh_intrinsic(nir_to_brw_state &ntb, UNREACHABLE("local invocation id should have been lowered earlier"); break; - case nir_intrinsic_load_local_invocation_index: + case nir_intrinsic_load_local_invocation_index_intel: dest = retype(dest, BRW_TYPE_UD); bld.MOV(dest, payload.local_index); break; diff --git a/src/intel/compiler/brw/brw_nir_lower_cs_intrinsics.c b/src/intel/compiler/brw/brw_nir_lower_cs_intrinsics.c index 10ecac944e3..7777a87ea84 100644 --- a/src/intel/compiler/brw/brw_nir_lower_cs_intrinsics.c +++ b/src/intel/compiler/brw/brw_nir_lower_cs_intrinsics.c @@ -15,7 +15,6 @@ struct lower_intrinsics_state { /* Per-block cached values. */ bool computed; - nir_def *hw_index; nir_def *local_index; nir_def *local_id; }; @@ -24,7 +23,6 @@ static void compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr *current) { assert(!state->computed); - state->hw_index = NULL; state->local_index = NULL; state->local_id = NULL; state->computed = true; @@ -68,13 +66,8 @@ compute_local_index_id(struct lower_intrinsics_state *state, nir_intrinsic_instr nir_def *linear; if (nir->info.stage == MESA_SHADER_MESH || nir->info.stage == MESA_SHADER_TASK) { - /* Thread payload provides a linear index, keep track of it - * so it doesn't get removed. - */ - state->hw_index = - current->intrinsic == nir_intrinsic_load_local_invocation_index ? - ¤t->def : nir_load_local_invocation_index(b); - linear = state->hw_index; + /* Thread payload provides a linear index, just use that. */ + linear = nir_load_local_invocation_index_intel(b); } else { nir_def *subgroup_id = nir_load_subgroup_id(b); nir_def *thread_local_id = @@ -245,10 +238,6 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state, if (!state->computed) compute_local_index_id(state, intrinsic); - /* Will be lowered later by the backend. */ - if (&intrinsic->def == state->hw_index) - continue; - assert(state->local_index); sysval = state->local_index; break; diff --git a/src/intel/vulkan/anv_shader_compile.c b/src/intel/vulkan/anv_shader_compile.c index 1ae61515c4b..06077eb55d2 100644 --- a/src/intel/vulkan/anv_shader_compile.c +++ b/src/intel/vulkan/anv_shader_compile.c @@ -1564,8 +1564,7 @@ anv_shader_lower_nir(struct anv_device *device, } } - if (mesa_shader_stage_is_compute(nir->info.stage) || - mesa_shader_stage_is_mesh(nir->info.stage)) { + if (mesa_shader_stage_is_compute(nir->info.stage)) { NIR_PASS(_, nir, brw_nir_lower_cs_intrinsics, compiler->devinfo, &shader_data->prog_data.cs); }