radv: determine subgroup/wave size early

This means we can actually implement varying subgroup size correctly.
It also means that we implement the implicit SPIR-V 1.6 full subgroups
requirement in compute shaders with cswave32/rtwave32.

In the future it will also allow more optimizations that use the subgroup size.

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>

The only somewhat complex case here is GFX10 geometry shaders when gewave32 is
used. In that case the subgroup size is only known once is_ngg has been decided,
because legacy GS doesn't support wave32.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37294>
This commit is contained in:
Georg Lehmann 2025-09-10 17:15:05 +02:00 committed by Marge Bot
parent 76a502d75a
commit a2d3cbac2a
4 changed files with 98 additions and 62 deletions

View file

@ -1754,6 +1754,9 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
b.shader->info.internal = false;
b.shader->info.workgroup_size[0] = 8;
b.shader->info.workgroup_size[1] = pdev->rt_wave_size == 64 ? 8 : 4;
b.shader->info.api_subgroup_size = pdev->rt_wave_size;
b.shader->info.max_subgroup_size = pdev->rt_wave_size;
b.shader->info.min_subgroup_size = pdev->rt_wave_size;
b.shader->info.shared_size = pdev->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false);

View file

@ -2265,6 +2265,15 @@ radv_fill_shader_info_ngg(struct radv_device *device, struct radv_shader_stage *
stages[MESA_SHADER_GEOMETRY].info.is_ngg = false;
}
}
/* Now that we know if ngg is used for geometry shaders, determine the subgroup size. */
if (stages[MESA_SHADER_GEOMETRY].nir) {
unsigned wave_size = stages[MESA_SHADER_GEOMETRY].info.is_ngg
? stages[MESA_SHADER_GEOMETRY].nir->info.min_subgroup_size
: stages[MESA_SHADER_GEOMETRY].nir->info.max_subgroup_size;
stages[MESA_SHADER_GEOMETRY].nir->info.max_subgroup_size = wave_size;
stages[MESA_SHADER_GEOMETRY].nir->info.min_subgroup_size = wave_size;
}
}
static bool

View file

@ -338,23 +338,87 @@ radv_compiler_debug(void *private_data, enum aco_compiler_debug_level level, con
vk_debug_report(&instance->vk, vk_flags[level] | VK_DEBUG_REPORT_DEBUG_BIT_EXT, NULL, 0, 0, "radv", message);
}
static void
radv_shader_choose_subgroup_size(struct radv_device *device, nir_shader *nir,
const struct radv_shader_stage_key *stage_key, unsigned spirv_version)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
VkPipelineShaderStageRequiredSubgroupSizeCreateInfo rss_info = {
.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
.requiredSubgroupSize = stage_key->subgroup_required_size * 32,
};
vk_set_subgroup_size(&device->vk, nir, spirv_version, rss_info.requiredSubgroupSize ? &rss_info : NULL,
stage_key->subgroup_allow_varying, stage_key->subgroup_require_full);
nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
unsigned wave_size;
if (nir->info.min_subgroup_size == nir->info.max_subgroup_size) {
wave_size = nir->info.min_subgroup_size;
} else if (mesa_shader_stage_uses_workgroup(nir->info.stage)) {
const unsigned local_size =
nir->info.workgroup_size[0] * nir->info.workgroup_size[1] * nir->info.workgroup_size[2];
unsigned default_wave_size;
if (nir->info.ray_queries)
default_wave_size = pdev->rt_wave_size;
else if (nir->info.stage == MESA_SHADER_MESH)
default_wave_size = pdev->ge_wave_size;
else
default_wave_size = pdev->cs_wave_size;
/* Games don't always request full subgroups when they should, which can cause bugs if cswave32
* is enabled. Furthermore, if cooperative matrices or subgroup info are used, we can't transparently change
* the subgroup size.
*/
const bool require_full_subgroups =
(default_wave_size == 32 &&
(nir->info.uses_wide_subgroup_intrinsics ||
(nir->info.stage == MESA_SHADER_COMPUTE && nir->info.cs.has_cooperative_matrix)) &&
local_size % RADV_SUBGROUP_SIZE == 0);
/* Use wave32 for small workgroups. */
if (local_size <= 32)
wave_size = 32;
else if (require_full_subgroups)
wave_size = 64;
else
wave_size = default_wave_size;
} else if (nir->info.stage == MESA_SHADER_GEOMETRY &&
(pdev->info.gfx_level >= GFX10 && pdev->info.gfx_level <= GFX10_3)) {
/* Legacy GS doesn't support wave32. */
wave_size = 64;
} else if (nir->info.stage == MESA_SHADER_FRAGMENT) {
wave_size = pdev->ps_wave_size;
} else if (mesa_shader_stage_is_rt(nir->info.stage)) {
wave_size = pdev->rt_wave_size;
} else {
wave_size = pdev->ge_wave_size;
}
if (nir->info.api_subgroup_size == 0) {
/* Report real wave_size for allow_varying. */
nir->info.api_subgroup_size = wave_size;
}
nir->info.max_subgroup_size = wave_size;
/* We might still decide to use ngg later. */
if (nir->info.stage == MESA_SHADER_GEOMETRY)
nir->info.min_subgroup_size = pdev->ge_wave_size;
else
nir->info.min_subgroup_size = wave_size;
}
nir_shader *
radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_stage *stage,
const struct radv_spirv_to_nir_options *options, bool is_internal)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
const struct radv_instance *instance = radv_physical_device_instance(pdev);
unsigned subgroup_size = 64, ballot_bit_size = 64;
const unsigned required_subgroup_size = stage->key.subgroup_required_size * 32;
if (required_subgroup_size) {
/* Only compute/mesh/task shaders currently support requiring a
* specific subgroup size.
*/
assert(stage->stage >= MESA_SHADER_COMPUTE);
subgroup_size = required_subgroup_size;
ballot_bit_size = required_subgroup_size;
}
nir_shader *nir;
if (stage->internal_nir) {
@ -366,6 +430,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
nir_validate_shader(nir, "in internal shader");
assert(exec_list_length(&nir->functions) == 1);
radv_shader_choose_subgroup_size(device, nir, &stage->key, 0);
} else {
uint32_t *spirv = (uint32_t *)stage->spirv.data;
assert(stage->spirv.size % 4 == 0);
@ -456,6 +521,8 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
*/
NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
radv_shader_choose_subgroup_size(device, nir, &stage->key, vk_spirv_version(spirv, stage->spirv.size));
progress = false;
NIR_PASS(progress, nir, nir_lower_cooperative_matrix_flexible_dimensions, 16, 16, 16);
if (progress) {
@ -463,7 +530,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
}
NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, subgroup_size);
NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, nir->info.max_subgroup_size);
/* Split member structs. We do this before lower_io_to_temporaries so that
* it doesn't lower system values to temporaries by accident.
@ -587,8 +654,8 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
NIR_PASS(_, nir, nir_lower_subgroups,
&(struct nir_lower_subgroups_options){
.subgroup_size = subgroup_size,
.ballot_bit_size = ballot_bit_size,
.subgroup_size = nir->info.api_subgroup_size,
.ballot_bit_size = nir->info.api_subgroup_size,
.ballot_components = 1,
.lower_to_scalar = 1,
.lower_subgroup_masks = 1,

View file

@ -481,27 +481,6 @@ radv_set_vs_output_param(struct radv_device *device, const struct nir_shader *ni
outinfo->prim_param_exports = total_param_exports - outinfo->param_exports;
}
/**
 * Return the wave size to use for \p stage.
 *
 * An explicitly required subgroup size in the stage key (stored in units of
 * 32 lanes) always wins. Otherwise the size depends on the stage: non-NGG
 * geometry shaders use 64, compute and task shaders use the value already
 * stored in \p info, fragment shaders use the PS wave size, ray tracing stages
 * the RT wave size, and every remaining stage the GE wave size.
 */
static uint8_t
radv_get_wave_size(struct radv_device *device, mesa_shader_stage stage, const struct radv_shader_info *info,
                   const struct radv_shader_stage_key *stage_key)
{
   const struct radv_physical_device *pdev = radv_device_physical(device);

   if (stage_key->subgroup_required_size)
      return stage_key->subgroup_required_size * 32;

   switch (stage) {
   case MESA_SHADER_GEOMETRY:
      if (!info->is_ngg)
         return 64;
      /* NGG geometry shaders are handled by the common tail below. */
      break;
   case MESA_SHADER_COMPUTE:
   case MESA_SHADER_TASK:
      return info->wave_size;
   case MESA_SHADER_FRAGMENT:
      return pdev->ps_wave_size;
   default:
      break;
   }

   return mesa_shader_stage_is_rt(stage) ? pdev->rt_wave_size : pdev->ge_wave_size;
}
static uint32_t
radv_compute_esgs_itemsize(const struct radv_device *device, uint32_t num_varyings)
{
@ -920,32 +899,6 @@ gather_shader_info_cs(struct radv_device *device, const nir_shader *nir, const s
struct radv_shader_info *info)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
unsigned default_wave_size = pdev->cs_wave_size;
if (info->cs.uses_rt)
default_wave_size = pdev->rt_wave_size;
unsigned local_size = nir->info.workgroup_size[0] * nir->info.workgroup_size[1] * nir->info.workgroup_size[2];
/* Games don't always request full subgroups when they should, which can cause bugs if cswave32
* is enabled. Furthermore, if cooperative matrices or subgroup info are used, we can't transparently change
* the subgroup size.
*/
const bool require_full_subgroups =
stage_key->subgroup_require_full || nir->info.cs.has_cooperative_matrix ||
(default_wave_size == 32 && nir->info.uses_wide_subgroup_intrinsics && local_size % RADV_SUBGROUP_SIZE == 0);
const unsigned required_subgroup_size = stage_key->subgroup_required_size * 32;
if (required_subgroup_size) {
info->wave_size = required_subgroup_size;
} else if (require_full_subgroups) {
info->wave_size = RADV_SUBGROUP_SIZE;
} else if (pdev->info.gfx_level >= GFX10 && local_size <= 32) {
/* Use wave32 for small workgroups. */
info->wave_size = 32;
} else {
info->wave_size = default_wave_size;
}
if (pdev->info.has_cs_regalloc_hang_bug) {
info->cs.regalloc_hang_bug = info->cs.block_size[0] * info->cs.block_size[1] * info->cs.block_size[2] > 256;
@ -1180,7 +1133,11 @@ radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *n
break;
}
info->wave_size = radv_get_wave_size(device, nir->info.stage, info, stage_key);
info->wave_size = nir->info.min_subgroup_size;
assert(info->wave_size == nir->info.max_subgroup_size);
assert(info->wave_size == 32 || info->wave_size == 64);
assert(pdev->info.gfx_level >= GFX10 || info->wave_size == 64);
assert(nir->info.stage != MESA_SHADER_GEOMETRY || info->is_ngg || info->wave_size == 64);
switch (nir->info.stage) {
case MESA_SHADER_COMPUTE: