diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 5e06f163b81..c6b788e28b7 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -5408,6 +5408,7 @@ typedef struct nir_lower_compute_system_values_options { bool lower_local_invocation_index:1; bool lower_cs_local_id_to_index:1; bool lower_workgroup_id_to_index:1; + uint16_t num_workgroups[3]; /* Compile-time-known dispatch sizes, or 0 if unknown. */ } nir_lower_compute_system_values_options; bool nir_lower_compute_system_values(nir_shader *shader, diff --git a/src/compiler/nir/nir_lower_system_values.c b/src/compiler/nir/nir_lower_system_values.c index 5253efd0098..9904cc80543 100644 --- a/src/compiler/nir/nir_lower_system_values.c +++ b/src/compiler/nir/nir_lower_system_values.c @@ -672,10 +672,18 @@ lower_compute_system_value_instr(nir_builder *b, if (options && options->has_base_workgroup_id) return nir_iadd(b, nir_u2uN(b, nir_load_workgroup_id_zero_base(b), bit_size), nir_load_base_workgroup_id(b, bit_size)); - else if (options && options->lower_workgroup_id_to_index) - return lower_id_to_index_no_umod(b, nir_load_workgroup_index(b), + else if (options && options->lower_workgroup_id_to_index) { + nir_ssa_def *wg_idx = nir_load_workgroup_index(b); + + nir_ssa_def *val = + try_lower_id_to_index_1d(b, wg_idx, options->num_workgroups); + if (val) + return val; + + return lower_id_to_index_no_umod(b, wg_idx, nir_load_num_workgroups(b, bit_size), bit_size); + } return NULL;