diff --git a/src/amd/vulkan/nir/radv_nir.h b/src/amd/vulkan/nir/radv_nir.h index 1336a67560f..e6315d4a2f0 100644 --- a/src/amd/vulkan/nir/radv_nir.h +++ b/src/amd/vulkan/nir/radv_nir.h @@ -84,6 +84,8 @@ void radv_nir_lower_poly_line_smooth(nir_shader *nir, const struct radv_graphics bool radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size); +bool radv_nir_lower_draw_id_to_zero(nir_shader *shader); + #ifdef __cplusplus } #endif diff --git a/src/amd/vulkan/nir/radv_nir_lower_io.c b/src/amd/vulkan/nir/radv_nir_lower_io.c index e78523a7a49..d9cebde0367 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_io.c +++ b/src/amd/vulkan/nir/radv_nir_lower_io.c @@ -179,3 +179,24 @@ radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *s return false; } + +static bool +radv_nir_lower_draw_id_to_zero_callback(struct nir_builder *b, nir_intrinsic_instr *intrin, UNUSED void *state) +{ + if (intrin->intrinsic != nir_intrinsic_load_draw_id) + return false; + + nir_def *replacement = nir_imm_zero(b, intrin->def.num_components, intrin->def.bit_size); + nir_def_rewrite_uses(&intrin->def, replacement); + nir_instr_remove(&intrin->instr); + nir_instr_free(&intrin->instr); + + return true; +} + +bool +radv_nir_lower_draw_id_to_zero(nir_shader *shader) +{ + return nir_shader_intrinsics_pass(shader, radv_nir_lower_draw_id_to_zero_callback, + nir_metadata_block_index | nir_metadata_dominance, NULL); +} diff --git a/src/amd/vulkan/radv_pipeline_graphics.c b/src/amd/vulkan/radv_pipeline_graphics.c index e1dd6bfd7c9..b2d9bf1fdea 100644 --- a/src/amd/vulkan/radv_pipeline_graphics.c +++ b/src/amd/vulkan/radv_pipeline_graphics.c @@ -1580,6 +1580,10 @@ radv_link_mesh(const struct radv_device *device, struct radv_shader_stage *mesh_ nir_foreach_shader_out_variable (var, mesh_stage->nir) { var->data.driver_location = 0; } + + /* Lower mesh shader draw ID to zero prevent app bugs from triggering undefined behaviour. */ + if (mesh_stage->info.ms.has_task && BITSET_TEST(mesh_stage->nir->info.system_values_read, SYSTEM_VALUE_DRAW_ID)) + radv_nir_lower_draw_id_to_zero(mesh_stage->nir); } static void