diff --git a/src/kosmickrisp/bridge/mtl_encoder.m b/src/kosmickrisp/bridge/mtl_encoder.m index 02a1430df30..1561e7d66c4 100644 --- a/src/kosmickrisp/bridge/mtl_encoder.m +++ b/src/kosmickrisp/bridge/mtl_encoder.m @@ -208,11 +208,8 @@ mtl_dispatch_threads(mtl_compute_encoder *encoder, { @autoreleasepool { id enc = (id)encoder; - MTLSize thread_count = MTLSizeMake(grid_size.x * local_size.x, - grid_size.y * local_size.y, - grid_size.z * local_size.z); - MTLSize threads_per_threadgroup = MTLSizeMake(local_size.x, - local_size.y, + MTLSize thread_count = MTLSizeMake(grid_size.x, grid_size.y, grid_size.z); + MTLSize threads_per_threadgroup = MTLSizeMake(local_size.x, local_size.y, local_size.z); // TODO_KOSMICKRISP can we rely on nonuniform threadgroup size support? diff --git a/src/kosmickrisp/vulkan/kk_cmd_dispatch.c b/src/kosmickrisp/vulkan/kk_cmd_dispatch.c index ef4cb988d84..871f1563e25 100644 --- a/src/kosmickrisp/vulkan/kk_cmd_dispatch.c +++ b/src/kosmickrisp/vulkan/kk_cmd_dispatch.c @@ -70,13 +70,14 @@ kk_CmdDispatchBase(VkCommandBuffer commandBuffer, uint32_t baseGroupX, kk_flush_compute_state(cmd); struct kk_shader *cs = cmd->state.shaders[MESA_SHADER_COMPUTE]; + struct mtl_size local_size = cs->info.cs.local_size; struct mtl_size grid_size = { - .x = groupCountX, - .y = groupCountY, - .z = groupCountZ, + .x = groupCountX * local_size.x, + .y = groupCountY * local_size.y, + .z = groupCountZ * local_size.z, }; mtl_compute_encoder *enc = kk_compute_encoder(cmd); - mtl_dispatch_threads(enc, grid_size, cs->info.cs.local_size); + mtl_dispatch_threads(enc, grid_size, local_size); } VKAPI_ATTR void VKAPI_CALL