diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index 8a0bfbd951b..6f05f73d334 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -6536,6 +6536,12 @@ radv_bind_geometry_shader(struct radv_cmd_buffer *cmd_buffer, const struct radv_ cmd_buffer->state.emitted_vs_prolog = NULL; } +static void +radv_bind_gs_copy_shader(struct radv_cmd_buffer *cmd_buffer, struct radv_shader *gs_copy_shader) +{ + cmd_buffer->state.gs_copy_shader = gs_copy_shader; +} + static void radv_bind_mesh_shader(struct radv_cmd_buffer *cmd_buffer, const struct radv_shader *ms) { @@ -6601,6 +6607,15 @@ radv_bind_task_shader(struct radv_cmd_buffer *cmd_buffer, const struct radv_shad cmd_buffer->task_rings_needed = true; } +static void +radv_bind_rt_prolog(struct radv_cmd_buffer *cmd_buffer, struct radv_shader *rt_prolog) +{ + cmd_buffer->state.rt_prolog = rt_prolog; + + const unsigned max_scratch_waves = radv_get_max_scratch_waves(cmd_buffer->device, rt_prolog); + cmd_buffer->compute_scratch_waves_wanted = MAX2(cmd_buffer->compute_scratch_waves_wanted, max_scratch_waves); +} + /* This function binds/unbinds a shader to the cmdbuffer state. */ static void radv_bind_shader(struct radv_cmd_buffer *cmd_buffer, struct radv_shader *shader, gl_shader_stage stage) @@ -6729,7 +6744,7 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint); radv_bind_shader(cmd_buffer, rt_pipeline->base.base.shaders[MESA_SHADER_INTERSECTION], MESA_SHADER_INTERSECTION); - cmd_buffer->state.rt_prolog = rt_pipeline->prolog; + radv_bind_rt_prolog(cmd_buffer, rt_pipeline->prolog); cmd_buffer->state.rt_pipeline = rt_pipeline; cmd_buffer->push_constant_stages |= RADV_RT_STAGE_BITS; @@ -6738,8 +6753,6 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline if (rt_pipeline->stack_size != -1u) cmd_buffer->state.rt_stack_size = rt_pipeline->stack_size; - const unsigned max_scratch_waves = radv_get_max_scratch_waves(cmd_buffer->device, rt_pipeline->prolog); - cmd_buffer->compute_scratch_waves_wanted = MAX2(cmd_buffer->compute_scratch_waves_wanted, max_scratch_waves); break; } case VK_PIPELINE_BIND_POINT_GRAPHICS: { @@ -6760,7 +6773,8 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline radv_bind_shader(cmd_buffer, graphics_pipeline->base.shaders[stage], stage); } - cmd_buffer->state.gs_copy_shader = graphics_pipeline->base.gs_copy_shader; + radv_bind_gs_copy_shader(cmd_buffer, graphics_pipeline->base.gs_copy_shader); + cmd_buffer->state.last_vgt_shader = graphics_pipeline->base.shaders[graphics_pipeline->last_vgt_api_stage]; cmd_buffer->state.graphics_pipeline = graphics_pipeline; @@ -9375,9 +9389,12 @@ radv_bind_graphics_shaders(struct radv_cmd_buffer *cmd_buffer) cmd_buffer->state.last_vgt_shader = cmd_buffer->state.shaders[last_vgt_api_stage]; - cmd_buffer->state.gs_copy_shader = cmd_buffer->state.shader_objs[MESA_SHADER_GEOMETRY] - ? cmd_buffer->state.shader_objs[MESA_SHADER_GEOMETRY]->gs.copy_shader - : NULL; + struct radv_shader *gs_copy_shader = cmd_buffer->state.shader_objs[MESA_SHADER_GEOMETRY] + ? cmd_buffer->state.shader_objs[MESA_SHADER_GEOMETRY]->gs.copy_shader + : NULL; + + radv_bind_gs_copy_shader(cmd_buffer, gs_copy_shader); + if (cmd_buffer->state.gs_copy_shader) { radv_cs_add_buffer(cmd_buffer->device->ws, cmd_buffer->cs, cmd_buffer->state.gs_copy_shader->bo); }