diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c index ab02a912808..2d7e2385a46 100644 --- a/src/amd/vulkan/radv_shader_info.c +++ b/src/amd/vulkan/radv_shader_info.c @@ -565,11 +565,15 @@ radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *n info->cs.uses_local_invocation_idx = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_LOCAL_INVOCATION_INDEX) | BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_SUBGROUP_ID) | BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_NUM_SUBGROUPS); + + if (nir->info.stage == MESA_SHADER_COMPUTE || nir->info.stage == MESA_SHADER_TASK) { + for (int i = 0; i < 3; ++i) + info->cs.block_size[i] = nir->info.workgroup_size[i]; + } + switch (nir->info.stage) { case MESA_SHADER_COMPUTE: case MESA_SHADER_TASK: - for (int i = 0; i < 3; ++i) - info->cs.block_size[i] = nir->info.workgroup_size[i]; info->cs.uses_ray_launch_size = BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_RAY_LAUNCH_SIZE_ADDR_AMD); /* Task shaders always need these for the I/O lowering even if