diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index 7fd7d728cb2..3ddfc3b9a6a 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -662,15 +662,24 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options); } + bool lower_local_invocation_index = false; + + if (nir->info.derivative_group == DERIVATIVE_GROUP_QUADS && + ((nir->info.stage == MESA_SHADER_COMPUTE || nir->info.stage == MESA_SHADER_TASK || + (nir->info.stage == MESA_SHADER_MESH && pdev->info.mesh_fast_launch_2)))) { + lower_local_invocation_index = true; + } else if (nir->info.stage == MESA_SHADER_COMPUTE && + (((nir->info.workgroup_size[0] == 1) + (nir->info.workgroup_size[1] == 1) + + (nir->info.workgroup_size[2] == 1)) == 2)) { + lower_local_invocation_index = true; + } + nir_lower_compute_system_values_options csv_options = { /* Mesh shaders run as NGG which can implement local_invocation_index from * the wave ID in merged_wave_info, but they don't have local_invocation_ids on GFX10.3. */ .lower_cs_local_id_to_index = nir->info.stage == MESA_SHADER_MESH && !pdev->info.mesh_fast_launch_2, - .lower_local_invocation_index = nir->info.stage == MESA_SHADER_COMPUTE && - ((((nir->info.workgroup_size[0] == 1) + (nir->info.workgroup_size[1] == 1) + - (nir->info.workgroup_size[2] == 1)) == 2) || - nir->info.derivative_group == DERIVATIVE_GROUP_QUADS), + .lower_local_invocation_index = lower_local_invocation_index, }; NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options);