diff --git a/.pick_status.json b/.pick_status.json index e5a5ebe6e06..1b7fa266b6f 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -1304,7 +1304,7 @@ "description": "radv: fix local invocation index for mesh/task and quad derivatives on GFX12", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c index fbe029bea13..19bc126f5bf 100644 --- a/src/amd/vulkan/radv_shader.c +++ b/src/amd/vulkan/radv_shader.c @@ -655,15 +655,24 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options); } + bool lower_local_invocation_index = false; + + if (nir->info.derivative_group == DERIVATIVE_GROUP_QUADS && + ((nir->info.stage == MESA_SHADER_COMPUTE || nir->info.stage == MESA_SHADER_TASK || + (nir->info.stage == MESA_SHADER_MESH && pdev->info.mesh_fast_launch_2)))) { + lower_local_invocation_index = true; + } else if (nir->info.stage == MESA_SHADER_COMPUTE && + (((nir->info.workgroup_size[0] == 1) + (nir->info.workgroup_size[1] == 1) + + (nir->info.workgroup_size[2] == 1)) == 2)) { + lower_local_invocation_index = true; + } + nir_lower_compute_system_values_options csv_options = { /* Mesh shaders run as NGG which can implement local_invocation_index from * the wave ID in merged_wave_info, but they don't have local_invocation_ids on GFX10.3. */ .lower_cs_local_id_to_index = nir->info.stage == MESA_SHADER_MESH && !pdev->info.mesh_fast_launch_2, - .lower_local_invocation_index = nir->info.stage == MESA_SHADER_COMPUTE && - ((((nir->info.workgroup_size[0] == 1) + (nir->info.workgroup_size[1] == 1) + - (nir->info.workgroup_size[2] == 1)) == 2) || - nir->info.derivative_group == DERIVATIVE_GROUP_QUADS), + .lower_local_invocation_index = lower_local_invocation_index, }; NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options);