kk: Rework command buffers' compute shader state tracking

Signed-off-by: Aitor Camacho <aitor@lunarg.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40763>
This commit is contained in:
Aitor Camacho 2026-04-01 18:41:25 +09:00 committed by Marge Bot
parent 8444723c6a
commit 2952ae2861
5 changed files with 19 additions and 19 deletions

View file

@ -154,9 +154,6 @@ struct kk_graphics_state {
struct kk_compute_state {
struct kk_descriptor_state descriptors;
mtl_compute_pipeline_state *pipeline_state;
struct mtl_size local_size;
enum kk_dirty dirty;
};
struct kk_encoder;
@ -170,6 +167,8 @@ struct kk_cmd_buffer {
struct {
struct kk_graphics_state gfx;
struct kk_compute_state cs;
struct kk_shader *shaders[MESA_SHADER_STAGES];
uint32_t dirty_shaders;
} state;
/* Owned large BOs */

View file

@ -47,8 +47,8 @@ kk_flush_compute_state(struct kk_cmd_buffer *cmd)
if (root_buffer)
mtl_compute_set_buffer(enc, root_buffer->map, 0, 0);
mtl_compute_set_pipeline_state(enc, cmd->state.cs.pipeline_state);
cmd->state.cs.dirty = 0u;
mtl_compute_set_pipeline_state(
enc, cmd->state.shaders[MESA_SHADER_COMPUTE]->pipeline.cs);
}
VKAPI_ATTR void VKAPI_CALL
@ -69,13 +69,14 @@ kk_CmdDispatchBase(VkCommandBuffer commandBuffer, uint32_t baseGroupX,
kk_flush_compute_state(cmd);
struct kk_shader *cs = cmd->state.shaders[MESA_SHADER_COMPUTE];
struct mtl_size grid_size = {
.x = groupCountX,
.y = groupCountY,
.z = groupCountZ,
};
mtl_compute_encoder *enc = kk_compute_encoder(cmd);
mtl_dispatch_threads(enc, grid_size, cmd->state.cs.local_size);
mtl_dispatch_threads(enc, grid_size, cs->info.cs.local_size);
}
VKAPI_ATTR void VKAPI_CALL
@ -95,7 +96,8 @@ kk_CmdDispatchIndirect(VkCommandBuffer commandBuffer, VkBuffer _buffer,
kk_flush_compute_state(cmd);
struct kk_shader *cs = cmd->state.shaders[MESA_SHADER_COMPUTE];
mtl_compute_encoder *enc = kk_compute_encoder(cmd);
mtl_dispatch_threadgroups_with_indirect_buffer(
enc, buffer->mtl_handle, offset, cmd->state.cs.local_size);
enc, buffer->mtl_handle, offset, cs->info.cs.local_size);
}

View file

@ -69,6 +69,7 @@ struct kk_meta_save {
struct vk_vertex_input_state _dynamic_vi;
struct vk_sample_locations_state _dynamic_sl;
struct vk_dynamic_graphics_state dynamic;
struct kk_shader *shaders[MESA_SHADER_STAGES];
struct {
union {
struct {
@ -78,10 +79,6 @@ struct kk_meta_save {
enum mtl_visibility_result_mode occlusion;
bool is_ds_dynamic;
} gfx;
struct {
mtl_compute_pipeline_state *pipeline_state;
struct mtl_size local_size;
} cs;
};
} pipeline;
struct kk_descriptor_set *desc0;
@ -116,8 +113,8 @@ kk_meta_begin(struct kk_cmd_buffer *cmd, struct kk_meta_save *save,
cmd->state.gfx.dirty |= KK_DIRTY_OCCLUSION;
desc->root_dirty = true;
} else {
save->pipeline.cs.pipeline_state = cmd->state.cs.pipeline_state;
save->pipeline.cs.local_size = cmd->state.cs.local_size;
save->shaders[MESA_SHADER_COMPUTE] =
cmd->state.shaders[MESA_SHADER_COMPUTE];
}
save->vb0_handle = cmd->state.gfx.vb.handles[0];
@ -179,8 +176,7 @@ kk_meta_end(struct kk_cmd_buffer *cmd, struct kk_meta_save *save,
desc->root_dirty = true;
} else {
cmd->state.cs.local_size = save->pipeline.cs.local_size;
cmd->state.cs.pipeline_state = save->pipeline.cs.pipeline_state;
kk_cmd_bind_compute_shader(cmd, save->shaders[MESA_SHADER_COMPUTE]);
}
memcpy(desc->root.push, save->push, sizeof(save->push));

View file

@ -1324,12 +1324,10 @@ kk_deserialize_shader(struct vk_device *vk_dev, struct blob_reader *blob,
return VK_SUCCESS;
}
static void
void
kk_cmd_bind_compute_shader(struct kk_cmd_buffer *cmd, struct kk_shader *shader)
{
cmd->state.cs.pipeline_state = shader->pipeline.cs;
cmd->state.cs.dirty |= KK_DIRTY_PIPELINE;
cmd->state.cs.local_size = shader->info.cs.local_size;
cmd->state.shaders[MESA_SHADER_COMPUTE] = shader;
}
static void

View file

@ -17,6 +17,8 @@
#include "vk_shader.h"
struct kk_cmd_buffer;
struct kk_shader_info {
mesa_shader_stage stage;
union {
@ -114,4 +116,7 @@ VkResult kk_compile_nir_shader(struct kk_device *dev, nir_shader *nir,
const VkAllocationCallbacks *alloc,
struct kk_shader **shader_out);
void kk_cmd_bind_compute_shader(struct kk_cmd_buffer *cmd,
struct kk_shader *shader);
#endif /* KK_SHADER_H */