diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c
index 27a4ce40780..364ab3a68f6 100644
--- a/src/microsoft/compiler/dxil_nir.c
+++ b/src/microsoft/compiler/dxil_nir.c
@@ -2074,3 +2074,102 @@ dxil_nir_lower_sample_pos(nir_shader *s)
 {
    return nir_shader_lower_instructions(s, is_sample_pos, lower_sample_pos, NULL);
 }
+
+/* Replaces load_subgroup_id with a value computed once in the start block of
+ * the function being processed: one elected invocation per subgroup atomically
+ * increments a shared counter and the result is read back by the whole
+ * subgroup. `data` points at a nir_ssa_def * caching that value, so the
+ * prologue is only emitted for the first load that is encountered.
+ */
+static bool
+lower_subgroup_id(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+   if (intr->intrinsic != nir_intrinsic_load_subgroup_id)
+      return false;
+
+   nir_ssa_def **subgroup_id = (nir_ssa_def **)data;
+   if (*subgroup_id == NULL) {
+      nir_variable *subgroup_id_counter = nir_variable_create(b->shader, nir_var_mem_shared, glsl_uint_type(), "dxil_SubgroupID_counter");
+      nir_variable *subgroup_id_local = nir_local_variable_create(b->impl, glsl_uint_type(), "dxil_SubgroupID_local");
+      b->cursor = nir_before_block(nir_start_block(b->impl));
+      nir_store_var(b, subgroup_id_local, nir_imm_int(b, 0), 1);
+
+      /* The invocation with local index 0 zeroes the shared counter. */
+      nir_deref_instr *counter_deref = nir_build_deref_var(b, subgroup_id_counter);
+      nir_ssa_def *tid = nir_load_local_invocation_index(b);
+      nir_if *nif = nir_push_if(b, nir_ieq_imm(b, tid, 0));
+      nir_store_deref(b, counter_deref, nir_imm_int(b, 0), 1);
+      nir_pop_if(b, nif);
+
+      /* A memory barrier alone does not order the reset above against the
+       * atomics below in other subgroups: they could increment the counter
+       * before it is zeroed. Make every invocation wait at a workgroup
+       * execution barrier as well.
+       */
+      nir_scoped_barrier(b,
+                         .execution_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_semantics = NIR_MEMORY_ACQ_REL,
+                         .memory_modes = nir_var_mem_shared);
+
+      /* One elected invocation per subgroup claims the next counter value... */
+      nif = nir_push_if(b, nir_elect(b, 1));
+      nir_ssa_def *subgroup_id_first_thread = nir_deref_atomic_add(b, 32, &counter_deref->dest.ssa, nir_imm_int(b, 1));
+      nir_store_var(b, subgroup_id_local, subgroup_id_first_thread, 1);
+      nir_pop_if(b, nif);
+
+      /* ...and the rest of the subgroup reads it from the first invocation. */
+      nir_ssa_def *subgroup_id_loaded = nir_load_var(b, subgroup_id_local);
+      *subgroup_id = nir_read_first_invocation(b, subgroup_id_loaded);
+   }
+   nir_ssa_def_rewrite_uses(&intr->dest.ssa, *subgroup_id);
+   return true;
+}
+
+/* Lowers every load_subgroup_id in the shader. No metadata is preserved
+ * because the lowering inserts new control flow.
+ */
+bool
+dxil_nir_lower_subgroup_id(nir_shader *s)
+{
+   nir_ssa_def *subgroup_id = NULL;
+   return nir_shader_instructions_pass(s, lower_subgroup_id, nir_metadata_none, &subgroup_id);
+}
+
+/* Replaces load_num_subgroups with
+ * ceil(workgroup_size.x * workgroup_size.y * workgroup_size.z / subgroup_size).
+ */
+static bool
+lower_num_subgroups(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+   if (intr->intrinsic != nir_intrinsic_load_num_subgroups)
+      return false;
+
+   b->cursor = nir_before_instr(instr);
+   nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
+   nir_ssa_def *size_minus_one = nir_iadd_imm(b, subgroup_size, -1);
+   nir_ssa_def *workgroup_size_vec = nir_load_workgroup_size(b);
+   nir_ssa_def *workgroup_size = nir_imul(b, nir_channel(b, workgroup_size_vec, 0),
+                                             nir_imul(b, nir_channel(b, workgroup_size_vec, 1),
+                                                         nir_channel(b, workgroup_size_vec, 2)));
+   /* Adding (subgroup_size - 1) before dividing rounds the quotient up. */
+   nir_ssa_def *ret = nir_idiv(b, nir_iadd(b, workgroup_size, size_minus_one), subgroup_size);
+   nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+   return true;
+}
+
+/* Lowers every load_num_subgroups in the shader. */
+bool
+dxil_nir_lower_num_subgroups(nir_shader *s)
+{
+   return nir_shader_instructions_pass(s, lower_num_subgroups,
+                                       nir_metadata_block_index |
+                                       nir_metadata_dominance |
+                                       nir_metadata_loop_analysis, NULL);
+}
diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h
index 57baba34d65..adeea6d3fad 100644
--- a/src/microsoft/compiler/dxil_nir.h
+++ b/src/microsoft/compiler/dxil_nir.h
@@ -76,6 +76,8 @@ bool dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mas
 bool dxil_nir_lower_discard_and_terminate(nir_shader* s);
 bool dxil_nir_ensure_position_writes(nir_shader *s);
 bool dxil_nir_lower_sample_pos(nir_shader *s);
+bool dxil_nir_lower_subgroup_id(nir_shader *s);
+bool dxil_nir_lower_num_subgroups(nir_shader *s);
 
 #ifdef __cplusplus
 }