diff --git a/src/panfrost/vulkan/csf/panvk_cmd_buffer.h b/src/panfrost/vulkan/csf/panvk_cmd_buffer.h index 8753fd34bb5..1dc447375be 100644 --- a/src/panfrost/vulkan/csf/panvk_cmd_buffer.h +++ b/src/panfrost/vulkan/csf/panvk_cmd_buffer.h @@ -672,8 +672,8 @@ void panvk_per_arch(cmd_inherit_render_state)( static inline void panvk_per_arch(calculate_task_axis_and_increment)( const struct panvk_shader_variant *shader, - struct panvk_physical_device *phys_dev, unsigned *task_axis, - unsigned *task_increment) + struct panvk_physical_device *phys_dev, const struct pan_compute_dim *wg_dim, + unsigned *task_axis, unsigned *task_increment) { /* Pick the task_axis and task_increment to maximize thread * utilization. */ @@ -682,14 +682,17 @@ panvk_per_arch(calculate_task_axis_and_increment)( unsigned max_thread_cnt = pan_compute_max_thread_count( &phys_dev->kmod.props, shader->info.work_reg_count); unsigned threads_per_task = threads_per_wg; - unsigned local_size[3] = { - shader->cs.local_size.x, - shader->cs.local_size.y, - shader->cs.local_size.z, - }; + const unsigned wg_count[3] = {wg_dim->x, wg_dim->y, wg_dim->z}; + const unsigned total_wgs = wg_dim->x * wg_dim->y * wg_dim->z; + + if (!total_wgs) { + *task_axis = MALI_TASK_AXIS_X; + *task_increment = 1; + return; + } for (unsigned i = 0; i < 3; i++) { - if (threads_per_task * local_size[i] >= max_thread_cnt) { + if (threads_per_task * wg_count[i] >= max_thread_cnt) { /* We reached out thread limit, stop at the current axis and * calculate the increment so it doesn't exceed the per-core * thread capacity. @@ -701,11 +704,11 @@ panvk_per_arch(calculate_task_axis_and_increment)( * threads. Pick the current axis grid size as our increment * as there's no point using something bigger. */ - *task_increment = local_size[i]; + *task_increment = wg_count[i]; break; } - threads_per_task *= local_size[i]; + threads_per_task *= wg_count[i]; (*task_axis)++; } diff --git a/src/panfrost/vulkan/csf/panvk_vX_cmd_dispatch.c b/src/panfrost/vulkan/csf/panvk_vX_cmd_dispatch.c index c94c3674303..14bb1ff0854 100644 --- a/src/panfrost/vulkan/csf/panvk_vX_cmd_dispatch.c +++ b/src/panfrost/vulkan/csf/panvk_vX_cmd_dispatch.c @@ -296,7 +296,7 @@ cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info) unsigned task_axis = MALI_TASK_AXIS_X; unsigned task_increment = 0; panvk_per_arch(calculate_task_axis_and_increment)( - cs, phys_dev, &task_axis, &task_increment); + cs, phys_dev, &dim, &task_axis, &task_increment); cs_trace_run_compute(b, tracing_ctx, cs_scratch_reg_tuple(b, 0, 4), task_increment, task_axis, cs_shader_res_sel(0, 0, 0, 0)); diff --git a/src/panfrost/vulkan/csf/panvk_vX_cmd_precomp.c b/src/panfrost/vulkan/csf/panvk_vX_cmd_precomp.c index 3a5864ec203..48e842543aa 100644 --- a/src/panfrost/vulkan/csf/panvk_vX_cmd_precomp.c +++ b/src/panfrost/vulkan/csf/panvk_vX_cmd_precomp.c @@ -117,7 +117,7 @@ panvk_per_arch(dispatch_precomp)(struct panvk_precomp_ctx *ctx, unsigned task_axis = MALI_TASK_AXIS_X; unsigned task_increment = 0; panvk_per_arch(calculate_task_axis_and_increment)( - shader, phys_dev, &task_axis, &task_increment); + shader, phys_dev, &dim, &task_axis, &task_increment); cs_trace_run_compute(b, tracing_ctx, cs_scratch_reg_tuple(b, 0, 4), task_increment, task_axis, cs_shader_res_sel(0, 0, 0, 0));