From 44ba856089f727f75d3783c73ce2188eb9842dcf Mon Sep 17 00:00:00 2001 From: Samuel Pitoiset Date: Fri, 3 Jan 2025 01:54:00 -0800 Subject: [PATCH] ac/nir: fix lowering subgroup ID for compute shaders on GFX12 This is lowered in backend compilers (LLVM or ACO) because it needs to access ttmp registers which aren't exposed to NIR. Signed-off-by: Samuel Pitoiset Part-of: --- src/amd/common/ac_nir.c | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/amd/common/ac_nir.c b/src/amd/common/ac_nir.c index f66cd301bc5..e8b749ab0dd 100644 --- a/src/amd/common/ac_nir.c +++ b/src/amd/common/ac_nir.c @@ -159,10 +159,7 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b) if (s->workgroup_size <= s->wave_size) { return nir_imm_int(b, 0); } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) { - if (s->gfx_level >= GFX12) - return false; - - assert(s->args->tg_size.used); + assert(s->gfx_level < GFX12 && s->args->tg_size.used); if (s->gfx_level >= GFX10_3) { return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5); @@ -198,6 +195,8 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state) switch (intrin->intrinsic) { case nir_intrinsic_load_subgroup_id: + if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) + return false; /* Lowered in backend compilers. */ replacement = load_subgroup_id_lowered(s, b); break; case nir_intrinsic_load_num_subgroups: { @@ -556,8 +555,16 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state) nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0); replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64); } else { + nir_def *subgroup_id; + + if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) { + subgroup_id = nir_load_subgroup_id(b); + } else { + subgroup_id = load_subgroup_id_lowered(s, b); + } + replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), - nir_imul_imm(b, load_subgroup_id_lowered(s, b), s->wave_size)); + nir_imul_imm(b, subgroup_id, s->wave_size)); } break; case nir_intrinsic_load_subgroup_invocation: