diff --git a/src/amd/common/ac_nir.c b/src/amd/common/ac_nir.c index f66cd301bc5..e8b749ab0dd 100644 --- a/src/amd/common/ac_nir.c +++ b/src/amd/common/ac_nir.c @@ -159,10 +159,7 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b) if (s->workgroup_size <= s->wave_size) { return nir_imm_int(b, 0); } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) { - if (s->gfx_level >= GFX12) - return false; - - assert(s->args->tg_size.used); + assert(s->gfx_level < GFX12 && s->args->tg_size.used); if (s->gfx_level >= GFX10_3) { return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5); @@ -198,6 +195,8 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state) switch (intrin->intrinsic) { case nir_intrinsic_load_subgroup_id: + if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) + return false; /* Lowered in backend compilers. */ replacement = load_subgroup_id_lowered(s, b); break; case nir_intrinsic_load_num_subgroups: { @@ -556,8 +555,16 @@ lower_intrinsic_to_arg(nir_builder *b, nir_instr *instr, void *state) nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0); replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64); } else { + nir_def *subgroup_id; + + if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER) { + subgroup_id = nir_load_subgroup_id(b); + } else { + subgroup_id = load_subgroup_id_lowered(s, b); + } + replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), - nir_imul_imm(b, load_subgroup_id_lowered(s, b), s->wave_size)); + nir_imul_imm(b, subgroup_id, s->wave_size)); } break; case nir_intrinsic_load_subgroup_invocation: