microsoft/compiler: Add lowering passes for basic subgroup vars

DXIL doesn't have a "subgroup ID" or "num subgroups" construct,
so add lowering to construct them. Subgroup ID is done using
once-per-subgroup atomics on a workgroup-shared variable, and
then broadcasting that (using read_first_invocation) to the other
threads. Num subgroups is just a division with the workgroup size.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20777>
This commit is contained in:
Jesse Natalie 2023-01-18 14:13:41 -08:00 committed by Marge Bot
parent a422df4b61
commit 2f8a8b5949
2 changed files with 75 additions and 0 deletions

View file

@ -2074,3 +2074,76 @@ dxil_nir_lower_sample_pos(nir_shader *s)
{
return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
}
static bool
lower_subgroup_id(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
return false;
nir_ssa_def **subgroup_id = (nir_ssa_def **)data;
if (*subgroup_id == NULL) {
nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter");
nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local");
b->cursor = nir_before_block(nir_start_block(b->impl));
nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1);
nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
nir_ssa_def *tid = nir_load_local_invocation_index(b);
nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0));
nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
nir_pop_if(b, nif);
nir_scoped_memory_barrier(b, NIR_SCOPE_WORKGROUP, NIR_MEMORY_ACQ_REL, nir_var_mem_shared);
nif = nir_push_if(b, nir_elect(b, 1));
nir_ssa_def *subgroup_id_first_thread = nir_deref_atomic_add(b, 32, &counter_deref->dest.ssa, nir_imm_int(b, 1));
nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
nir_pop_if(b, nif);
nir_ssa_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
*subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded);
}
nir_ssa_def_rewrite_uses(&intr->dest.ssa, *subgroup_id);
return true;
}
bool
dxil_nir_lower_subgroup_id(nir_shader *s)
{
nir_ssa_def *subgroup_id = NULL;
return nir_shader_instructions_pass(s, lower_subgroup_id, nir_metadata_none, &subgroup_id);
}
static bool
lower_num_subgroups(nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
return false;
b->cursor = nir_before_instr(instr);
nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
nir_ssa_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1);
nir_ssa_def *workgroup_size_vec = nir_load_workgroup_size(b);
nir_ssa_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0),
nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
nir_channel(b, workgroup_size_vec, 2)));
nir_ssa_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
return true;
}
bool
dxil_nir_lower_num_subgroups(nir_shader *s)
{
return nir_shader_instructions_pass(s, lower_num_subgroups,
nir_metadata_block_index |
nir_metadata_dominance |
nir_metadata_loop_analysis, NULL);
}

View file

@ -76,6 +76,8 @@ bool dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mas
bool dxil_nir_lower_discard_and_terminate(nir_shader* s);
bool dxil_nir_ensure_position_writes(nir_shader *s);
bool dxil_nir_lower_sample_pos(nir_shader *s);
bool dxil_nir_lower_subgroup_id(nir_shader *s);
bool dxil_nir_lower_num_subgroups(nir_shader *s);
#ifdef __cplusplus
}