From 6d9f56396088e5b37c05f40cd0a58c7efc48521d Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Wed, 19 Nov 2025 12:46:10 -0500 Subject: [PATCH] spirv: Assume variable workgroup size unless it's set This fixes an issue a bunch of different components were all working around themselves where sometimes we don't have a workgroup size but workgroup_size_variable is false. This also fixes asahi_clc, which didn't have the workaround and was assuming zero (but not variable!) workgroup sizes everywhere. LoLed-by: Alyssa Rosenzweig Acked-by: Mel Henning Part-of: --- src/compiler/spirv/spirv_to_nir.c | 25 +++++++++++-------- src/gallium/frontends/rusticl/core/kernel.rs | 1 - .../frontends/rusticl/mesa/compiler/nir.rs | 9 ------- src/microsoft/clc/clc_compiler.c | 1 - src/panfrost/clc/pan_compile.c | 6 ----- 5 files changed, 15 insertions(+), 27 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 9f78561943d..64ad0311aec 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -5490,6 +5490,7 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point, case SpvExecutionModeLocalSize: if (mesa_shader_stage_uses_workgroup(b->shader->info.stage)) { + b->shader->info.workgroup_size_variable = false; b->shader->info.workgroup_size[0] = mode->operands[0]; b->shader->info.workgroup_size[1] = mode->operands[1]; b->shader->info.workgroup_size[2] = mode->operands[2]; @@ -5826,6 +5827,7 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin switch (mode->exec_mode) { case SpvExecutionModeLocalSizeId: if (mesa_shader_stage_uses_workgroup(b->shader->info.stage)) { + b->shader->info.workgroup_size_variable = false; b->shader->info.workgroup_size[0] = vtn_constant_uint(b, mode->operands[0]); b->shader->info.workgroup_size[1] = vtn_constant_uint(b, mode->operands[1]); b->shader->info.workgroup_size[2] = vtn_constant_uint(b, mode->operands[2]); @@ -7300,6 +7302,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count, b->shader = nir_shader_create(b, stage, nir_options); b->shader->info.float_controls_execution_mode = options->float_controls_execution_mode; + b->shader->info.workgroup_size_variable = true; b->shader->info.cs.shader_index = options->shader_index; b->shader->has_debug_info = options->debug_info; _mesa_blake3_compute(words, word_count * sizeof(uint32_t), b->shader->info.source_blake3); @@ -7429,21 +7432,23 @@ spirv_to_nir(const uint32_t *words, size_t word_count, /* Parse execution modes that depend on IDs. Must happen after we have * constants parsed. */ - if (!options->create_library) + if (!options->create_library) { vtn_foreach_execution_mode(b, b->entry_point, vtn_handle_execution_mode_id, NULL); - if (b->workgroup_size_builtin) { - vtn_assert(mesa_shader_stage_uses_workgroup(stage)); - vtn_assert(b->workgroup_size_builtin->type->type == - glsl_vector_type(GLSL_TYPE_UINT, 3)); + if (b->workgroup_size_builtin) { + vtn_assert(mesa_shader_stage_uses_workgroup(stage)); + vtn_assert(b->workgroup_size_builtin->type->type == + glsl_vector_type(GLSL_TYPE_UINT, 3)); - nir_const_value *const_size = - b->workgroup_size_builtin->constant->values; + nir_const_value *const_size = + b->workgroup_size_builtin->constant->values; - b->shader->info.workgroup_size[0] = const_size[0].u32; - b->shader->info.workgroup_size[1] = const_size[1].u32; - b->shader->info.workgroup_size[2] = const_size[2].u32; + b->shader->info.workgroup_size_variable = false; + b->shader->info.workgroup_size[0] = const_size[0].u32; + b->shader->info.workgroup_size[1] = const_size[1].u32; + b->shader->info.workgroup_size[2] = const_size[2].u32; + } } /* Set types on all vtn_values */ diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 49a88db2925..5132f5b8392 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -706,7 +706,6 @@ fn compile_nir_to_args( nir.set_fp_rounding_mode_rtne(); nir_pass!(nir, nir_scale_fdiv); - nir.set_workgroup_size_variable_if_zero(); nir.structurize(); nir_pass!( nir, diff --git a/src/gallium/frontends/rusticl/mesa/compiler/nir.rs b/src/gallium/frontends/rusticl/mesa/compiler/nir.rs index 4bbf33bc1ed..6fa07948d4d 100644 --- a/src/gallium/frontends/rusticl/mesa/compiler/nir.rs +++ b/src/gallium/frontends/rusticl/mesa/compiler/nir.rs @@ -369,15 +369,6 @@ impl NirShader { unsafe { (*self.nir.as_ptr()).info.num_subgroups } } - pub fn set_workgroup_size_variable_if_zero(&mut self) { - let nir = self.nir.as_ptr(); - unsafe { - (*nir) - .info - .set_workgroup_size_variable((*nir).info.workgroup_size[0] == 0); - } - } - pub fn set_workgroup_size(&mut self, size: [u16; 3]) { let nir = unsafe { self.nir.as_mut() }; nir.info.set_workgroup_size_variable(false); diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c index 58cb72daaba..c90ab921a0f 100644 --- a/src/microsoft/clc/clc_compiler.c +++ b/src/microsoft/clc/clc_compiler.c @@ -801,7 +801,6 @@ clc_spirv_to_dxil(struct clc_libclc *lib, clc_error(logger, "spirv_to_nir() failed"); goto err_free_dxil; } - nir->info.workgroup_size_variable = true; NIR_PASS(_, nir, nir_lower_goto_ifs); NIR_PASS(_, nir, nir_opt_dead_cf); diff --git a/src/panfrost/clc/pan_compile.c b/src/panfrost/clc/pan_compile.c index 401ccfd60cd..5e30968416e 100644 --- a/src/panfrost/clc/pan_compile.c +++ b/src/panfrost/clc/pan_compile.c @@ -100,12 +100,6 @@ compile(void *memctx, const uint32_t *spirv, size_t spirv_size, unsigned arch) nir_shader *nir = spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL, "library", &spirv_options, nir_options); - /* Workgroup size may be different between different entrypoints, so we - * mark it as variable to prevent it from being lowered to a constant while - * we are still processing all entrypoints together. This is tempoary, - * nir_precompiled_build_variant will set the fixed workgroup size for each - * entrypoint and set workgroup_size_variable back to false. */ - nir->info.workgroup_size_variable = true; nir_validate_shader(nir, "after spirv_to_nir"); nir_validate_ssa_dominance(nir, "after spirv_to_nir"); ralloc_steal(memctx, nir);