diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index e81630fa7ed..705a87adcaa 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -13406,21 +13406,29 @@ radv_after_dispatch(struct radv_cmd_buffer *cmd_buffer, VkPipelineBindPoint bind static void radv_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info, - struct radv_compute_pipeline *pipeline, VkPipelineBindPoint bind_point) + struct radv_compute_pipeline *pipeline, const struct radv_shader *shader, VkPipelineBindPoint bind_point) { - struct radv_shader *compute_shader = bind_point == VK_PIPELINE_BIND_POINT_COMPUTE - ? cmd_buffer->state.shaders[MESA_SHADER_COMPUTE] - : cmd_buffer->state.rt_prolog; - radv_before_dispatch(cmd_buffer, pipeline, bind_point); - radv_emit_dispatch_packets(cmd_buffer, compute_shader, info); + radv_emit_dispatch_packets(cmd_buffer, shader, info); radv_after_dispatch(cmd_buffer, bind_point, false); } void radv_compute_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info) { - radv_dispatch(cmd_buffer, info, cmd_buffer->state.compute_pipeline, VK_PIPELINE_BIND_POINT_COMPUTE); + struct radv_compute_pipeline *pipeline = cmd_buffer->state.compute_pipeline; + const struct radv_shader *shader = cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]; + + radv_dispatch(cmd_buffer, info, pipeline, shader, VK_PIPELINE_BIND_POINT_COMPUTE); +} + +static void +radv_rt_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info) +{ + struct radv_compute_pipeline *pipeline = &cmd_buffer->state.rt_pipeline->base; + const struct radv_shader *shader = cmd_buffer->state.rt_prolog; + + radv_dispatch(cmd_buffer, info, pipeline, shader, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR); } VKAPI_ATTR void VKAPI_CALL @@ -13575,7 +13583,6 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, VkTraceRaysIndirectCommand2K if (unlikely(device->rra_trace.ray_history_buffer)) radv_trace_trace_rays(cmd_buffer, tables, indirect_va); - struct radv_compute_pipeline *pipeline = &cmd_buffer->state.rt_pipeline->base; struct radv_shader *rt_prolog = cmd_buffer->state.rt_prolog; /* Since the workgroup size is 8x4 (or 8x8), 1D dispatches can only fill 8 threads per wave at most. To increase @@ -13636,7 +13643,7 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, VkTraceRaysIndirectCommand2K assert(cs->b->cdw <= cdw_max); - radv_dispatch(cmd_buffer, &info, pipeline, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR); + radv_rt_dispatch(cmd_buffer, &info); if (remaining_ray_count) { info.blocks[0] = remaining_ray_count; @@ -13652,7 +13659,7 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, VkTraceRaysIndirectCommand2K radeon_end(); } - radv_dispatch(cmd_buffer, &info, pipeline, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR); + radv_rt_dispatch(cmd_buffer, &info); } radv_resume_conditional_rendering(cmd_buffer);