diff --git a/src/amd/vulkan/radv_shader_object.c b/src/amd/vulkan/radv_shader_object.c index 5397037ffbd..e36196a6dc9 100644 --- a/src/amd/vulkan/radv_shader_object.c +++ b/src/amd/vulkan/radv_shader_object.c @@ -150,7 +150,36 @@ radv_shader_object_init_graphics(struct radv_shader_object *shader_obj, struct r struct radv_shader *shader = NULL; struct radv_shader_binary *binary = NULL; - if (!pCreateInfo->nextStage) { + VkShaderStageFlags next_stages = pCreateInfo->nextStage; + if (!next_stages) { + /* When next stage is 0, gather all valid next stages. */ + switch (pCreateInfo->stage) { + case VK_SHADER_STAGE_VERTEX_BIT: + next_stages |= + VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT; + break; + case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT: + next_stages |= VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + break; + case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: + next_stages |= VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT; + break; + case VK_SHADER_STAGE_GEOMETRY_BIT: + case VK_SHADER_STAGE_MESH_BIT_EXT: + next_stages |= VK_SHADER_STAGE_FRAGMENT_BIT; + break; + case VK_SHADER_STAGE_TASK_BIT_EXT: + next_stages |= VK_SHADER_STAGE_MESH_BIT_EXT; + break; + case VK_SHADER_STAGE_FRAGMENT_BIT: + case VK_SHADER_STAGE_COMPUTE_BIT: + break; + default: + unreachable("Invalid shader stage"); + } + } + + if (!next_stages) { struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES] = {NULL}; struct radv_shader_binary *binaries[MESA_VULKAN_SHADER_STAGES] = {NULL}; @@ -165,7 +194,7 @@ radv_shader_object_init_graphics(struct radv_shader_object *shader_obj, struct r shader_obj->shader = shader; shader_obj->binary = binary; } else { - radv_foreach_stage(next_stage, pCreateInfo->nextStage) + radv_foreach_stage(next_stage, next_stages) { struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES] = {NULL}; struct radv_shader_binary *binaries[MESA_VULKAN_SHADER_STAGES] = {NULL};