zink: handle drivers with multiple subgroup sizes correctly

This makes use of VK_KHR_pipeline_executable_properties and
VK_EXT_subgroup_size_control to query supported subgroup sizes and to
query which one the driver picks for a given pipeline.

This fixes OpenCL subgroups on drivers with multiple supported subgroup
sizes.

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37169>
This commit is contained in:
Karol Herbst 2025-09-03 11:41:15 +02:00 committed by Marge Bot
parent 38ef732169
commit dee90bf617
5 changed files with 108 additions and 3 deletions

View file

@ -5673,7 +5673,14 @@ zink_shader_init(struct zink_screen *screen, struct zink_shader *zs)
{
nir_lower_subgroups_options subgroup_options = {0};
subgroup_options.lower_to_scalar = true;
subgroup_options.subgroup_size = screen->info.props11.subgroupSize;
if (nir->info.api_subgroup_size)
subgroup_options.subgroup_size = nir->info.api_subgroup_size;
else if (nir->info.stage != MESA_SHADER_KERNEL ||
!screen->info.feats13.subgroupSizeControl ||
screen->info.props13.minSubgroupSize == screen->info.props13.maxSubgroupSize)
subgroup_options.subgroup_size = screen->info.props11.subgroupSize;
subgroup_options.ballot_bit_size = 32;
subgroup_options.ballot_components = 4;
subgroup_options.lower_subgroup_masks = true;

View file

@ -489,6 +489,19 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
STATIC_ASSERT(ARRAY_SIZE(data) == ARRAY_SIZE(me));
}
/* pin the subgroup size whenever the driver can't tell us on compute kernels. */
VkPipelineShaderStageRequiredSubgroupSizeCreateInfo subInfo = {
.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
};
if (comp->shader->info.stage == MESA_SHADER_KERNEL &&
screen->info.feats13.subgroupSizeControl &&
(screen->info.props13.requiredSubgroupSizeStages & VK_SHADER_STAGE_COMPUTE_BIT) &&
screen->info.props13.minSubgroupSize != screen->info.props13.maxSubgroupSize &&
!screen->info.have_KHR_pipeline_executable_properties) {
subInfo.requiredSubgroupSize = zink_get_subgroup_size_for_block(screen, comp, state->local_size);
stage.pNext = &subInfo;
}
pci.stage = stage;
VkPipeline pipeline;

View file

@ -2339,6 +2339,67 @@ zink_bind_cs_state(struct pipe_context *pctx,
zink_select_launch_grid(ctx);
}
static uint32_t
zink_get_subgroup_sizes_for_pipeline(struct zink_screen *screen, VkPipeline pipeline)
{
struct VkPipelineInfoKHR info = {
.sType = VK_STRUCTURE_TYPE_PIPELINE_INFO_KHR,
.pipeline = pipeline,
};
uint32_t stat_count;
uint32_t subgroup_sizes = 0;
VKSCR(GetPipelineExecutablePropertiesKHR)(screen->dev, &info, &stat_count, NULL);
struct VkPipelineExecutablePropertiesKHR *stats = CALLOC(stat_count, sizeof(*stats));
VKSCR(GetPipelineExecutablePropertiesKHR)(screen->dev, &info, &stat_count, stats);
for (unsigned i = 0; i < stat_count; i++)
subgroup_sizes |= stats[i].subgroupSize;
FREE(stats);
return subgroup_sizes;
}
uint32_t
zink_get_subgroup_size_for_block(struct zink_screen *screen, struct zink_compute_program *comp,
const uint32_t block[3])
{
if (!screen->info.feats13.subgroupSizeControl)
return screen->base.compute_caps.subgroup_sizes;
if (screen->info.props13.minSubgroupSize == screen->info.props13.maxSubgroupSize)
return screen->info.props13.minSubgroupSize;
uint32_t subgroup_sizes = 0;
if (screen->info.have_KHR_pipeline_executable_properties) {
struct zink_compute_pipeline_state pipeline_state = {
.dirty = true,
.local_size = { block[0], block[1], block[2] },
.variable_shared_mem = 0, // TODO?,
};
VkPipeline pipeline = zink_get_compute_pipeline(screen, comp, &pipeline_state);
subgroup_sizes = zink_get_subgroup_sizes_for_pipeline(screen, pipeline);
if (util_bitcount(subgroup_sizes) == 1)
return subgroup_sizes;
} else {
uint32_t size = screen->info.props13.minSubgroupSize;
uint32_t max = screen->info.props13.maxSubgroupSize;
for (; size <= max; size <<= 1)
subgroup_sizes |= size;
}
uint32_t invocations = block[0] * block[1] * block[2];
while (subgroup_sizes) {
unsigned size = 1 << u_bit_scan(&subgroup_sizes);
if (invocations <= size * screen->info.props13.maxComputeWorkgroupSubgroups)
return size;
}
// Should never happen and if so, not our fault.
return 0;
}
static void
zink_get_compute_state_info(struct pipe_context *pctx, void *cso, struct pipe_compute_state_object_info *info)
{
@ -2349,7 +2410,7 @@ zink_get_compute_state_info(struct pipe_context *pctx, void *cso, struct pipe_co
info->private_memory = comp->scratch_size;
if (screen->info.props11.subgroupSize) {
info->preferred_simd_size = screen->info.props11.subgroupSize;
info->simd_sizes = info->preferred_simd_size;
info->simd_sizes = screen->base.compute_caps.subgroup_sizes;
} else {
// just guess it
info->preferred_simd_size = 64;
@ -2358,6 +2419,15 @@ zink_get_compute_state_info(struct pipe_context *pctx, void *cso, struct pipe_co
}
}
static uint32_t
zink_get_compute_state_subgroup_size(struct pipe_context *pctx, void *cso,
const uint32_t block[3])
{
struct zink_compute_program *comp = cso;
struct zink_screen *screen = zink_screen(pctx->screen);
return zink_get_subgroup_size_for_block(screen, comp, block);
}
static void
zink_delete_cs_shader_state(struct pipe_context *pctx, void *cso)
{
@ -2616,6 +2686,7 @@ zink_program_init(struct zink_context *ctx)
ctx->base.create_compute_state = zink_create_cs_state;
ctx->base.bind_compute_state = zink_bind_cs_state;
ctx->base.get_compute_state_info = zink_get_compute_state_info;
ctx->base.get_compute_state_subgroup_size = zink_get_compute_state_subgroup_size;
ctx->base.delete_compute_state = zink_delete_cs_shader_state;
if (zink_screen(ctx->base.screen)->info.have_EXT_vertex_input_dynamic_state)

View file

@ -509,6 +509,11 @@ zink_sanitize_optimal_key_mesh(struct zink_shader **shaders, uint32_t val)
k.fs.force_dual_color_blend = false;
return k.val;
}
uint32_t
zink_get_subgroup_size_for_block(struct zink_screen *screen, struct zink_compute_program *comp,
const uint32_t block[3]);
#ifdef __cplusplus
}
#endif

View file

@ -693,7 +693,16 @@ zink_init_compute_caps(struct zink_screen *screen)
caps->max_local_size =
screen->info.props.limits.maxComputeSharedMemorySize;
caps->subgroup_sizes = screen->info.props11.subgroupSize;
if (screen->info.feats13.subgroupSizeControl) {
uint32_t size = screen->info.props13.minSubgroupSize;
uint32_t max = screen->info.props13.maxSubgroupSize;
for (; size <= max; size <<= 1)
caps->subgroup_sizes |= size;
} else {
caps->subgroup_sizes = screen->info.props11.subgroupSize;
}
caps->max_mem_alloc_size = screen->clamp_video_mem;
caps->max_global_size = screen->total_video_mem;
/* no way in vulkan to retrieve this information. */