diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c index 5a07d652a0e..14b2e4ad202 100644 --- a/src/amd/vulkan/radv_pipeline.c +++ b/src/amd/vulkan/radv_pipeline.c @@ -2805,9 +2805,10 @@ radv_fill_shader_keys(struct radv_device *device, struct radv_shader_variant_key static uint8_t radv_get_wave_size(struct radv_device *device, const VkPipelineShaderStageCreateInfo *pStage, - gl_shader_stage stage, const struct radv_shader_variant_key *key) + gl_shader_stage stage, const struct radv_shader_variant_key *key, + const struct radv_shader_info *info) { - if (stage == MESA_SHADER_GEOMETRY && !key->vs_common_out.as_ngg) + if (stage == MESA_SHADER_GEOMETRY && !info->is_ngg) return 64; else if (stage == MESA_SHADER_COMPUTE) { return key->cs.subgroup_size; @@ -2918,7 +2919,7 @@ radv_fill_shader_info(struct radv_pipeline *pipeline, for (int i = 0; i < MESA_SHADER_STAGES; i++) { if (nir[i]) { - infos[i].wave_size = radv_get_wave_size(pipeline->device, pStages[i], i, &keys[i]); + infos[i].wave_size = radv_get_wave_size(pipeline->device, pStages[i], i, &keys[i], &infos[i]); infos[i].ballot_bit_size = radv_get_ballot_bit_size(pipeline->device, pStages[i], i, &keys[i]); }