diff --git a/src/nouveau/compiler/nak_nir.c b/src/nouveau/compiler/nak_nir.c
index f23abfc1688..a1ea62bb14d 100644
--- a/src/nouveau/compiler/nak_nir.c
+++ b/src/nouveau/compiler/nak_nir.c
@@ -178,6 +178,90 @@ lower_bit_size_cb(const nir_instr *instr, void *_data)
    }
 }
 
+/* Emits the unsigned division of n by d rounded up, as (n + (d - 1)) / d. */
+static nir_def *
+nir_udiv_round_up(nir_builder *b, nir_def *n, nir_def *d)
+{
+   nir_def *d_minus_one = nir_iadd_imm(b, d, -1);
+   nir_def *biased = nir_iadd(b, n, d_minus_one);
+
+   return nir_udiv(b, biased, d);
+}
+
+/*
+ * Callback handed to nir_shader_intrinsics_pass() that replaces
+ * load_num_subgroups and load_subgroup_id with values built at the removed
+ * intrinsic's position.
+ *
+ * If nak_nir_has_one_subgroup() holds for the shader, the replacements are
+ * the constants 1 and 0.  Otherwise:
+ *
+ *    num_subgroups = ceil(workgroup_size.x * .y * .z / subgroup_size)
+ *    subgroup_id   = local_invocation_index / subgroup_size
+ *
+ * Returns true if the intrinsic was replaced and false for every other
+ * intrinsic, which is left alone.  The data argument is not used.
+ */
+static bool
+nak_nir_lower_subgroup_id_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
+                                 void *data)
+{
+   const bool is_num_subgroups =
+      intrin->intrinsic == nir_intrinsic_load_num_subgroups;
+
+   if (!is_num_subgroups &&
+       intrin->intrinsic != nir_intrinsic_load_subgroup_id) {
+      return false;
+   }
+
+   /* New instructions are inserted where the intrinsic used to be. */
+   b->cursor = nir_instr_remove(&intrin->instr);
+
+   nir_def *replacement;
+   if (nak_nir_has_one_subgroup(b->shader)) {
+      replacement = nir_imm_int(b, is_num_subgroups ? 1 : 0);
+   } else {
+      /* NOTE(review): presumably the linear index math below is only valid
+       * without derivative groups -- confirm.
+       */
+      assert(b->shader->info.cs.derivative_group == DERIVATIVE_GROUP_NONE);
+
+      if (is_num_subgroups) {
+         nir_def *wg_size = nir_load_workgroup_size(b);
+         nir_def *invocations =
+            nir_imul(b, nir_imul(b, nir_channel(b, wg_size, 0),
+                                    nir_channel(b, wg_size, 1)),
+                        nir_channel(b, wg_size, 2));
+         nir_def *sg_size = nir_load_subgroup_size(b);
+
+         replacement = nir_udiv_round_up(b, invocations, sg_size);
+      } else {
+         nir_def *invocation_index = nir_load_local_invocation_index(b);
+         nir_def *sg_size = nir_load_subgroup_size(b);
+
+         replacement = nir_udiv(b, invocation_index, sg_size);
+      }
+   }
+
+   nir_def_rewrite_uses(&intrin->def, replacement);
+
+   return true;
+}
+
+/*
+ * Hands nak_nir_lower_subgroup_id_intrin() to nir_shader_intrinsics_pass()
+ * for this shader, passing the block-index and dominance metadata flags and
+ * no callback data.  Returns whatever that call returns.
+ */
+static bool
+nak_nir_lower_subgroup_id(nir_shader *nir)
+{
+   const nir_metadata metadata =
+      nir_metadata_block_index | nir_metadata_dominance;
+
+   return nir_shader_intrinsics_pass(nir, nak_nir_lower_subgroup_id_intrin,
+                                     metadata, NULL);
+}
+
 void
 nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
 {
@@ -214,6 +298,7 @@ nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
    OPT(nir, nir_lower_load_const_to_scalar);
    OPT(nir, nir_lower_var_copies);
    OPT(nir, nir_lower_system_values);
+   OPT(nir, nak_nir_lower_subgroup_id);
    OPT(nir, nir_lower_compute_system_values, NULL);
 
    const nir_lower_subgroups_options subgroups_options = {