nak: Lower subgroup_id and num_subgroups

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24998>
This commit is contained in:
Faith Ekstrand 2023-10-23 16:42:15 -05:00 committed by Marge Bot
parent 42a305416a
commit 143d88dcc3

View file

@ -178,6 +178,69 @@ lower_bit_size_cb(const nir_instr *instr, void *_data)
}
}
static nir_def *
nir_udiv_round_up(nir_builder *b, nir_def *n, nir_def *d)
{
return nir_udiv(b, nir_iadd(b, n, nir_iadd_imm(b, d, -1)), d);
}
static bool
nak_nir_lower_subgroup_id_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
void *data)
{
switch (intrin->intrinsic) {
case nir_intrinsic_load_num_subgroups: {
b->cursor = nir_instr_remove(&intrin->instr);
nir_def *num_subgroups;
if (nak_nir_has_one_subgroup(b->shader)) {
num_subgroups = nir_imm_int(b, 1);
} else {
assert(b->shader->info.cs.derivative_group == DERIVATIVE_GROUP_NONE);
nir_def *workgroup_size = nir_load_workgroup_size(b);
workgroup_size =
nir_imul(b, nir_imul(b, nir_channel(b, workgroup_size, 0),
nir_channel(b, workgroup_size, 1)),
nir_channel(b, workgroup_size, 2));
nir_def *subgroup_size = nir_load_subgroup_size(b);
num_subgroups = nir_udiv_round_up(b, workgroup_size, subgroup_size);
}
nir_def_rewrite_uses(&intrin->def, num_subgroups);
return true;
}
case nir_intrinsic_load_subgroup_id: {
b->cursor = nir_instr_remove(&intrin->instr);
nir_def *subgroup_id;
if (nak_nir_has_one_subgroup(b->shader)) {
subgroup_id = nir_imm_int(b, 0);
} else {
assert(b->shader->info.cs.derivative_group == DERIVATIVE_GROUP_NONE);
nir_def *invocation_index = nir_load_local_invocation_index(b);
nir_def *subgroup_size = nir_load_subgroup_size(b);
subgroup_id = nir_udiv(b, invocation_index, subgroup_size);
}
nir_def_rewrite_uses(&intrin->def, subgroup_id);
return true;
}
default:
return false;
}
}
static bool
nak_nir_lower_subgroup_id(nir_shader *nir)
{
return nir_shader_intrinsics_pass(nir, nak_nir_lower_subgroup_id_intrin,
nir_metadata_block_index |
nir_metadata_dominance,
NULL);
}
void
nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
{
@ -214,6 +277,7 @@ nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
OPT(nir, nir_lower_load_const_to_scalar);
OPT(nir, nir_lower_var_copies);
OPT(nir, nir_lower_system_values);
OPT(nir, nak_nir_lower_subgroup_id);
OPT(nir, nir_lower_compute_system_values, NULL);
const nir_lower_subgroups_options subgroups_options = {