panvk/csf: Add a helper for dispatching compute shaders with 3D state

Co-authored-by: Olivia Lee <olivia.lee@collabora.com>
Signed-off-by: Olivia Lee <olivia.lee@collabora.com>
Reviewed-by: Eric R. Smith <eric.smith@collabora.com>
Reviewed-by: Christian Gmeiner <cgmeiner@igalia.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41654>
This commit is contained in:
Faith Ekstrand 2025-11-10 15:52:53 -05:00 committed by Marge Bot
parent 412f7cb999
commit 8172e3d795
3 changed files with 123 additions and 60 deletions

View file

@ -837,6 +837,13 @@ panvk_per_arch(calculate_task_axis_and_increment)(
assert(*task_increment > 0);
}
/*
 * Emit a dispatch of the shader variant `cs` with caller-supplied state,
 * instead of reading everything from cmdbuf->state.compute.
 *
 * - cs_desc_state: shader descriptor state; its res_table is what gets
 *   programmed when CS or DESC_STATE is flagged dirty.
 * - push_uniforms: GPU address of the push uniforms; it is combined with
 *   cs->fau.total_count to build the FAU pointer when PUSH_UNIFORMS is dirty.
 * - tsd: GPU address of the thread storage descriptor to use (cmd_dispatch()
 *   passes the value returned by cmd_dispatch_prepare_tls()).
 * - info: dispatch parameters; a non-zero info->indirect.buffer_dev_addr
 *   selects the indirect path.
 *
 * The function only re-emits state whose compute dirty bit is set and clears
 * the compute dirty bits before returning, so callers that bypass the regular
 * compute state tracking must set the dirty bits themselves.
 */
void panvk_per_arch(cmd_dispatch_shader)(
struct panvk_cmd_buffer *cmdbuf,
const struct panvk_shader_variant *cs,
const struct panvk_shader_desc_state *cs_desc_state,
uint64_t push_uniforms, uint64_t tsd,
const struct panvk_dispatch_info *info);
static VkPipelineStageFlags2
panvk_get_subqueue_stages(enum panvk_subqueue_id subqueue)
{

View file

@ -119,25 +119,16 @@ panvk_per_arch(cmd_dispatch_prepare_tls)(
return tsd.gpu;
}
static void
cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info)
void
panvk_per_arch(cmd_dispatch_shader)(
struct panvk_cmd_buffer *cmdbuf,
const struct panvk_shader_variant *cs,
const struct panvk_shader_desc_state *cs_desc_state,
uint64_t push_uniforms, uint64_t tsd,
const struct panvk_dispatch_info *info)
{
const struct panvk_shader_variant *cs =
panvk_shader_only_variant(cmdbuf->state.compute.shader);
VkResult result;
/* If there's no compute shader, we can skip the dispatch. */
if (!panvk_priv_mem_check_alloc(cs->spd))
return;
struct panvk_physical_device *phys_dev =
to_panvk_physical_device(cmdbuf->vk.base.device->physical);
const struct panvk_shader_desc_info *cs_desc_info =
&cmdbuf->state.compute.shader->desc_info;
struct panvk_descriptor_state *desc_state =
&cmdbuf->state.compute.desc_state;
struct panvk_shader_desc_state *cs_desc_state =
&cmdbuf->state.compute.cs.desc;
const struct cs_tracing_ctx *tracing_ctx =
&cmdbuf->state.cs[PANVK_SUBQUEUE_COMPUTE].tracing;
@ -148,55 +139,14 @@ cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info)
};
bool indirect = info->indirect.buffer_dev_addr != 0;
uint64_t tsd =
panvk_per_arch(cmd_dispatch_prepare_tls)(cmdbuf, cs, &dim, indirect);
if (!tsd)
return;
/* Only used for indirect dispatch */
unsigned wg_per_task = 0;
if (indirect)
wg_per_task = pan_calc_workgroups_per_task(&cs->cs.local_size,
&phys_dev->kmod.dev->props);
if (compute_state_dirty(cmdbuf, DESC_STATE) ||
compute_state_dirty(cmdbuf, CS)) {
result = panvk_per_arch(cmd_prepare_push_descs)(
cmdbuf, desc_state, cs_desc_info->used_set_mask);
if (result != VK_SUCCESS)
return;
}
panvk_per_arch(cmd_prepare_dispatch_sysvals)(cmdbuf, info);
result = prepare_driver_set(cmdbuf);
if (result != VK_SUCCESS)
return;
result = panvk_per_arch(cmd_prepare_compute_push_uniforms)(
cmdbuf, cs, &cmdbuf->state.compute.push_uniforms);
if (result != VK_SUCCESS)
return;
if (compute_state_dirty(cmdbuf, CS) ||
compute_state_dirty(cmdbuf, DESC_STATE)) {
result = panvk_per_arch(cmd_prepare_shader_res_table)(
cmdbuf, desc_state, cs_desc_info, cs_desc_state, 1);
if (result != VK_SUCCESS)
return;
}
struct cs_builder *b = panvk_get_cs_builder(cmdbuf, PANVK_SUBQUEUE_COMPUTE);
/* Copy the global TLS pointer to the per-job TSD. */
if (cs->info.tls_size) {
cs_move64_to(b, cs_scratch_reg64(b, 0), cmdbuf->state.tls.desc.gpu);
cs_load64_to(b, cs_scratch_reg64(b, 2), cs_scratch_reg64(b, 0), 8);
cs_move64_to(b, cs_scratch_reg64(b, 0), tsd);
cs_store64(b, cs_scratch_reg64(b, 2), cs_scratch_reg64(b, 0), 8);
cs_flush_stores(b);
}
cs_update_compute_ctx(b) {
if (compute_state_dirty(cmdbuf, CS) ||
compute_state_dirty(cmdbuf, DESC_STATE))
@ -204,7 +154,7 @@ cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info)
cs_desc_state->res_table);
if (compute_state_dirty(cmdbuf, PUSH_UNIFORMS)) {
uint64_t fau_ptr = cmdbuf->state.compute.push_uniforms |
uint64_t fau_ptr = push_uniforms |
((uint64_t)cs->fau.total_count << 56);
cs_move64_to(b, cs_reg64(b, PANVK_COMPUTE_FAU), fau_ptr);
}
@ -244,8 +194,7 @@ cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info)
info->indirect.buffer_dev_addr);
cs_load_to(b, cs_sr_reg_tuple(b, COMPUTE, JOB_SIZE_X, 3),
cs_scratch_reg64(b, 0), BITFIELD_MASK(3), 0);
cs_move64_to(b, cs_scratch_reg64(b, 0),
cmdbuf->state.compute.push_uniforms);
cs_move64_to(b, cs_scratch_reg64(b, 0), push_uniforms);
if (shader_uses_sysval(cs, compute, num_work_groups.x)) {
cs_store32(b, cs_sr_reg32(b, COMPUTE, JOB_SIZE_X),
@ -346,6 +295,79 @@ cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info)
clear_dirty_after_dispatch(cmdbuf);
}
static void
cmd_dispatch(struct panvk_cmd_buffer *cmdbuf, struct panvk_dispatch_info *info)
{
const struct panvk_shader_variant *cs =
panvk_shader_only_variant(cmdbuf->state.compute.shader);
VkResult result;
/* If there's no compute shader, we can skip the dispatch. */
if (!panvk_priv_mem_check_alloc(cs->spd))
return;
const struct panvk_shader_desc_info *cs_desc_info =
&cmdbuf->state.compute.shader->desc_info;
struct panvk_descriptor_state *desc_state =
&cmdbuf->state.compute.desc_state;
struct panvk_shader_desc_state *cs_desc_state =
&cmdbuf->state.compute.cs.desc;
struct pan_compute_dim dim = {
info->direct.wg_count.x,
info->direct.wg_count.y,
info->direct.wg_count.z,
};
bool indirect = info->indirect.buffer_dev_addr != 0;
uint64_t tsd =
panvk_per_arch(cmd_dispatch_prepare_tls)(cmdbuf, cs, &dim, indirect);
if (!tsd)
return;
if (compute_state_dirty(cmdbuf, DESC_STATE) ||
compute_state_dirty(cmdbuf, CS)) {
result = panvk_per_arch(cmd_prepare_push_descs)(
cmdbuf, desc_state, cs_desc_info->used_set_mask);
if (result != VK_SUCCESS)
return;
}
panvk_per_arch(cmd_prepare_dispatch_sysvals)(cmdbuf, info);
result = prepare_driver_set(cmdbuf);
if (result != VK_SUCCESS)
return;
result = panvk_per_arch(cmd_prepare_compute_push_uniforms)(
cmdbuf, cs, &cmdbuf->state.compute.push_uniforms);
if (result != VK_SUCCESS)
return;
if (compute_state_dirty(cmdbuf, CS) ||
compute_state_dirty(cmdbuf, DESC_STATE)) {
result = panvk_per_arch(cmd_prepare_shader_res_table)(
cmdbuf, desc_state, cs_desc_info, cs_desc_state, 1);
if (result != VK_SUCCESS)
return;
}
struct cs_builder *b = panvk_get_cs_builder(cmdbuf, PANVK_SUBQUEUE_COMPUTE);
/* Copy the global TLS pointer to the per-job TSD. */
if (cs->info.tls_size) {
cs_move64_to(b, cs_scratch_reg64(b, 0), cmdbuf->state.tls.desc.gpu);
cs_load64_to(b, cs_scratch_reg64(b, 2), cs_scratch_reg64(b, 0), 8);
cs_move64_to(b, cs_scratch_reg64(b, 0), tsd);
cs_store64(b, cs_scratch_reg64(b, 2), cs_scratch_reg64(b, 0), 8);
cs_flush_stores(b);
}
panvk_per_arch(cmd_dispatch_shader)(cmdbuf, cs, cs_desc_state,
cmdbuf->state.compute.push_uniforms,
tsd, info);
}
VKAPI_ATTR void VKAPI_CALL
panvk_per_arch(CmdDispatchBase)(VkCommandBuffer commandBuffer,
uint32_t baseGroupX, uint32_t baseGroupY,

View file

@ -2696,6 +2696,40 @@ update_prims_generated_query(struct panvk_cmd_buffer *cmdbuf,
}
}
/*
 * Run a compute shader on behalf of the graphics path.
 *
 * The dispatch uses the graphics TSD (cmdbuf->state.gfx.tsd) together with the
 * caller-provided shader variant and descriptor state. If `push_uniforms` is
 * 0, the push uniforms are prepared here through
 * cmd_prepare_gfx_push_uniforms(); the dispatch is dropped if that fails.
 */
static void
launch_gfx_cs(struct panvk_cmd_buffer *cmdbuf,
              const struct panvk_shader_variant *cs,
              const struct panvk_shader_desc_state *cs_desc_state,
              uint64_t push_uniforms,
              const struct panvk_dispatch_info *info)
{
   /* Push uniforms of GFX compute shaders are emitted on every launch rather
    * than dirty-tracked: it keeps the interface much simpler, and since they
    * nearly always carry per-draw information, dirty checks would not save
    * anything in practice.
    */
   if (push_uniforms == 0 &&
       panvk_per_arch(cmd_prepare_gfx_push_uniforms)(
          cmdbuf, cs, &push_uniforms, 1) != VK_SUCCESS)
      return;

   /* Flag all compute state as dirty so cmd_dispatch_shader() emits the
    * shader, resource table and push uniforms passed in here.
    */
   compute_state_set_dirty(cmdbuf, CS);
   compute_state_set_dirty(cmdbuf, DESC_STATE);
   compute_state_set_dirty(cmdbuf, PUSH_UNIFORMS);

   panvk_per_arch(cmd_dispatch_shader)(cmdbuf, cs, cs_desc_state,
                                       push_uniforms, cmdbuf->state.gfx.tsd,
                                       info);

   /* cmd_dispatch_shader() clears the compute dirty bits on its way out, so
    * flag everything again; presumably this makes the next regular compute
    * dispatch re-emit its own state instead of inheriting ours.
    */
   compute_state_set_dirty(cmdbuf, CS);
   compute_state_set_dirty(cmdbuf, DESC_STATE);
   compute_state_set_dirty(cmdbuf, PUSH_UNIFORMS);
}
static void
launch_draw(struct panvk_cmd_buffer *cmdbuf,
const struct panvk_draw_info *draw)