diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index 0355d9f406b..543e8841912 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -1893,10 +1893,6 @@ radv_emit_graphics_pipeline(struct radv_cmd_buffer *cmd_buffer) if (cmd_buffer->state.emitted_graphics_pipeline == pipeline) return; - cmd_buffer->scratch_size_per_wave_needed = - MAX2(cmd_buffer->scratch_size_per_wave_needed, pipeline->base.scratch_bytes_per_wave); - cmd_buffer->scratch_waves_wanted = MAX2(cmd_buffer->scratch_waves_wanted, pipeline->base.max_waves); - if (cmd_buffer->state.emitted_graphics_pipeline) { if (radv_rast_prim_is_points_or_lines(cmd_buffer->state.emitted_graphics_pipeline->rast_prim) != radv_rast_prim_is_points_or_lines(pipeline->rast_prim)) cmd_buffer->state.dirty |= RADV_CMD_DIRTY_GUARDBAND; @@ -6128,6 +6124,10 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline if (graphics_pipeline->gsvs_ring_size > cmd_buffer->gsvs_ring_size_needed) cmd_buffer->gsvs_ring_size_needed = graphics_pipeline->gsvs_ring_size; + cmd_buffer->scratch_size_per_wave_needed = + MAX2(cmd_buffer->scratch_size_per_wave_needed, pipeline->scratch_bytes_per_wave); + cmd_buffer->scratch_waves_wanted = MAX2(cmd_buffer->scratch_waves_wanted, pipeline->max_waves); + if (radv_pipeline_has_stage(graphics_pipeline, MESA_SHADER_TESS_CTRL)) cmd_buffer->tess_rings_needed = true; if (mesh_shading)