From 2952ae28611cbd428c1f792b7acf2586fc4c5208 Mon Sep 17 00:00:00 2001 From: Aitor Camacho Date: Wed, 1 Apr 2026 18:41:25 +0900 Subject: [PATCH] kk: Rework command buffers' compute shader state tracking Signed-off-by: Aitor Camacho Part-of: --- src/kosmickrisp/vulkan/kk_cmd_buffer.h | 5 ++--- src/kosmickrisp/vulkan/kk_cmd_dispatch.c | 10 ++++++---- src/kosmickrisp/vulkan/kk_cmd_meta.c | 12 ++++-------- src/kosmickrisp/vulkan/kk_shader.c | 6 ++---- src/kosmickrisp/vulkan/kk_shader.h | 5 +++++ 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/kosmickrisp/vulkan/kk_cmd_buffer.h b/src/kosmickrisp/vulkan/kk_cmd_buffer.h index 1eaaaf2ccd9..ed50ec70e09 100644 --- a/src/kosmickrisp/vulkan/kk_cmd_buffer.h +++ b/src/kosmickrisp/vulkan/kk_cmd_buffer.h @@ -154,9 +154,6 @@ struct kk_graphics_state { struct kk_compute_state { struct kk_descriptor_state descriptors; - mtl_compute_pipeline_state *pipeline_state; - struct mtl_size local_size; - enum kk_dirty dirty; }; struct kk_encoder; @@ -170,6 +167,8 @@ struct kk_cmd_buffer { struct { struct kk_graphics_state gfx; struct kk_compute_state cs; + struct kk_shader *shaders[MESA_SHADER_STAGES]; + uint32_t dirty_shaders; } state; /* Owned large BOs */ diff --git a/src/kosmickrisp/vulkan/kk_cmd_dispatch.c b/src/kosmickrisp/vulkan/kk_cmd_dispatch.c index 00558e60b85..65545406ba1 100644 --- a/src/kosmickrisp/vulkan/kk_cmd_dispatch.c +++ b/src/kosmickrisp/vulkan/kk_cmd_dispatch.c @@ -47,8 +47,8 @@ kk_flush_compute_state(struct kk_cmd_buffer *cmd) if (root_buffer) mtl_compute_set_buffer(enc, root_buffer->map, 0, 0); - mtl_compute_set_pipeline_state(enc, cmd->state.cs.pipeline_state); - cmd->state.cs.dirty = 0u; + mtl_compute_set_pipeline_state( + enc, cmd->state.shaders[MESA_SHADER_COMPUTE]->pipeline.cs); } VKAPI_ATTR void VKAPI_CALL @@ -69,13 +69,14 @@ kk_CmdDispatchBase(VkCommandBuffer commandBuffer, uint32_t baseGroupX, kk_flush_compute_state(cmd); + struct kk_shader *cs = cmd->state.shaders[MESA_SHADER_COMPUTE]; struct mtl_size grid_size = { .x = groupCountX, .y = groupCountY, .z = groupCountZ, }; mtl_compute_encoder *enc = kk_compute_encoder(cmd); - mtl_dispatch_threads(enc, grid_size, cmd->state.cs.local_size); + mtl_dispatch_threads(enc, grid_size, cs->info.cs.local_size); } VKAPI_ATTR void VKAPI_CALL @@ -95,7 +96,8 @@ kk_CmdDispatchIndirect(VkCommandBuffer commandBuffer, VkBuffer _buffer, kk_flush_compute_state(cmd); + struct kk_shader *cs = cmd->state.shaders[MESA_SHADER_COMPUTE]; mtl_compute_encoder *enc = kk_compute_encoder(cmd); mtl_dispatch_threadgroups_with_indirect_buffer( - enc, buffer->mtl_handle, offset, cmd->state.cs.local_size); + enc, buffer->mtl_handle, offset, cs->info.cs.local_size); } diff --git a/src/kosmickrisp/vulkan/kk_cmd_meta.c b/src/kosmickrisp/vulkan/kk_cmd_meta.c index 11e540f7e8f..3d571be7e57 100644 --- a/src/kosmickrisp/vulkan/kk_cmd_meta.c +++ b/src/kosmickrisp/vulkan/kk_cmd_meta.c @@ -69,6 +69,7 @@ struct kk_meta_save { struct vk_vertex_input_state _dynamic_vi; struct vk_sample_locations_state _dynamic_sl; struct vk_dynamic_graphics_state dynamic; + struct kk_shader *shaders[MESA_SHADER_STAGES]; struct { union { struct { @@ -78,10 +79,6 @@ struct kk_meta_save { enum mtl_visibility_result_mode occlusion; bool is_ds_dynamic; } gfx; - struct { - mtl_compute_pipeline_state *pipeline_state; - struct mtl_size local_size; - } cs; }; } pipeline; struct kk_descriptor_set *desc0; @@ -116,8 +113,8 @@ kk_meta_begin(struct kk_cmd_buffer *cmd, struct kk_meta_save *save, cmd->state.gfx.dirty |= KK_DIRTY_OCCLUSION; desc->root_dirty = true; } else { - save->pipeline.cs.pipeline_state = cmd->state.cs.pipeline_state; - save->pipeline.cs.local_size = cmd->state.cs.local_size; + save->shaders[MESA_SHADER_COMPUTE] = + cmd->state.shaders[MESA_SHADER_COMPUTE]; } save->vb0_handle = cmd->state.gfx.vb.handles[0]; @@ -179,8 +176,7 @@ kk_meta_end(struct kk_cmd_buffer *cmd, struct kk_meta_save *save, desc->root_dirty = true; } else { - cmd->state.cs.local_size = save->pipeline.cs.local_size; - cmd->state.cs.pipeline_state = save->pipeline.cs.pipeline_state; + kk_cmd_bind_compute_shader(cmd, save->shaders[MESA_SHADER_COMPUTE]); } memcpy(desc->root.push, save->push, sizeof(save->push)); diff --git a/src/kosmickrisp/vulkan/kk_shader.c b/src/kosmickrisp/vulkan/kk_shader.c index 15c3711d6ab..8400a3f9e0d 100644 --- a/src/kosmickrisp/vulkan/kk_shader.c +++ b/src/kosmickrisp/vulkan/kk_shader.c @@ -1324,12 +1324,10 @@ kk_deserialize_shader(struct vk_device *vk_dev, struct blob_reader *blob, return VK_SUCCESS; } -static void +void kk_cmd_bind_compute_shader(struct kk_cmd_buffer *cmd, struct kk_shader *shader) { - cmd->state.cs.pipeline_state = shader->pipeline.cs; - cmd->state.cs.dirty |= KK_DIRTY_PIPELINE; - cmd->state.cs.local_size = shader->info.cs.local_size; + cmd->state.shaders[MESA_SHADER_COMPUTE] = shader; } static void diff --git a/src/kosmickrisp/vulkan/kk_shader.h b/src/kosmickrisp/vulkan/kk_shader.h index f1ecc345a1e..2dcc268b62e 100644 --- a/src/kosmickrisp/vulkan/kk_shader.h +++ b/src/kosmickrisp/vulkan/kk_shader.h @@ -17,6 +17,8 @@ #include "vk_shader.h" +struct kk_cmd_buffer; + struct kk_shader_info { mesa_shader_stage stage; union { @@ -114,4 +116,7 @@ VkResult kk_compile_nir_shader(struct kk_device *dev, nir_shader *nir, const VkAllocationCallbacks *alloc, struct kk_shader **shader_out); +void kk_cmd_bind_compute_shader(struct kk_cmd_buffer *cmd, + struct kk_shader *shader); + #endif /* KK_SHADER_H */