ac/nir: add ac_nir_lower_intrinsics_to_args_options structure

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39638>
This commit is contained in:
Marek Olšák 2026-01-31 15:38:11 -05:00 committed by Marge Bot
parent a9e47751d2
commit 1e11e83d1c
5 changed files with 85 additions and 68 deletions

View file

@ -89,10 +89,17 @@ ac_nir_load_smem(nir_builder *b, unsigned num_components, nir_def *addr, nir_def
bool ac_nir_lower_sin_cos(nir_shader *shader);
bool ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
unsigned wave_size, unsigned workgroup_size, bool use_llvm,
const struct ac_shader_args *ac_args);
typedef struct {
enum amd_gfx_level gfx_level;
bool has_ls_vgpr_init_bug;
const enum ac_hw_stage hw_stage;
unsigned wave_size;
unsigned workgroup_size;
bool use_llvm;
} ac_nir_lower_intrinsics_to_args_options;
bool ac_nir_lower_intrinsics_to_args(nir_shader *shader, const struct ac_shader_args *ac_args,
const ac_nir_lower_intrinsics_to_args_options *options);
nir_xfb_info *ac_nir_get_sorted_xfb_info(const nir_shader *nir);

View file

@ -12,12 +12,7 @@
typedef struct {
const struct ac_shader_args *const args;
const enum amd_gfx_level gfx_level;
bool use_llvm;
bool has_ls_vgpr_init_bug;
unsigned wave_size;
unsigned workgroup_size;
const enum ac_hw_stage hw_stage;
const ac_nir_lower_intrinsics_to_args_options *options;
nir_def *vertex_id;
nir_def *instance_id;
@ -36,8 +31,8 @@ preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct a
nir_def *value = ac_nir_load_arg_upper_bound(&start_b, s->args, arg, upper_bound);
/* If there are no HS threads, SPI mistakenly loads the LS VGPRs starting at VGPR 0. */
if ((s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER) &&
s->has_ls_vgpr_init_bug) {
if ((s->options->hw_stage == AC_HW_LOCAL_SHADER || s->options->hw_stage == AC_HW_HULL_SHADER) &&
s->options->has_ls_vgpr_init_bug) {
nir_def *count = ac_nir_unpack_arg(&start_b, s->args, s->args->merged_wave_info, 8, 8);
nir_def *hs_empty = nir_ieq_imm(&start_b, count, 0);
value = nir_bcsel(&start_b, hs_empty,
@ -50,14 +45,14 @@ preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct a
static nir_def *
load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
{
if (s->workgroup_size <= s->wave_size) {
if (s->options->workgroup_size <= s->options->wave_size) {
return nir_imm_int(b, 0);
} else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
if (s->gfx_level >= GFX12) {
assert(!s->use_llvm);
} else if (s->options->hw_stage == AC_HW_COMPUTE_SHADER) {
if (s->options->gfx_level >= GFX12) {
assert(!s->options->use_llvm);
nir_def *ttmp8 = nir_load_ttmp_register_amd(b, .base = 8);
return nir_ubfe_imm(b, ttmp8, 25, 5);
} else if (s->gfx_level >= GFX10_3) {
} else if (s->options->gfx_level >= GFX10_3) {
assert(s->args->tg_size.used);
return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
} else {
@ -68,8 +63,8 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
assert(s->args->tg_size.used);
return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
}
} else if (s->hw_stage == AC_HW_HULL_SHADER) {
if (s->gfx_level >= GFX11) {
} else if (s->options->hw_stage == AC_HW_HULL_SHADER) {
if (s->options->gfx_level >= GFX11) {
assert(s->args->tcs_wave_id.used);
return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
} else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
@ -94,12 +89,12 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
sgpr_local_invocation_index = nir_iadd(b, sgpr_patch_start, sgpr_invocation_id);
}
}
return nir_ushr_imm(b, sgpr_local_invocation_index, util_logbase2(s->wave_size));
return nir_ushr_imm(b, sgpr_local_invocation_index, util_logbase2(s->options->wave_size));
} else {
UNREACHABLE("unimplemented for LS");
}
} else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
} else if (s->options->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
s->options->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
assert(s->args->merged_wave_info.used);
return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
} else {
@ -117,17 +112,17 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
switch (intrin->intrinsic) {
case nir_intrinsic_load_subgroup_id:
/* LLVM uses an intrinsic to get this on gfx12. */
if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->use_llvm)
if (s->options->gfx_level >= GFX12 && s->options->hw_stage == AC_HW_COMPUTE_SHADER && s->options->use_llvm)
return false;
replacement = load_subgroup_id_lowered(s, b);
break;
case nir_intrinsic_load_num_subgroups: {
if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
if (s->options->hw_stage == AC_HW_COMPUTE_SHADER) {
assert(s->args->tg_size.used);
replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 0, 6);
} else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
} else if (s->options->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
s->options->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
assert(s->args->merged_wave_info.used);
replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 28, 4);
} else {
@ -141,19 +136,19 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
/* This lowering is only valid with fast_launch = 2, otherwise we assume that
* lower_workgroup_id_to_index removed any uses of the workgroup id by this point.
*/
assert(s->gfx_level >= GFX11);
assert(s->options->gfx_level >= GFX11);
nir_def *xy = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
nir_def *z = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
replacement = nir_vec3(b, nir_extract_u16(b, xy, nir_imm_int(b, 0)),
nir_extract_u16(b, xy, nir_imm_int(b, 1)),
nir_extract_u16(b, z, nir_imm_int(b, 1)));
} else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
} else if (s->options->hw_stage == AC_HW_COMPUTE_SHADER) {
nir_def *undef = nir_undef(b, 1, 32);
nir_def *ids[3] = {undef, undef, undef};
if (s->gfx_level >= GFX12) {
if (s->options->gfx_level >= GFX12) {
/* LLVM uses intrinsics to get workgroup IDs on gfx12. */
if (s->use_llvm)
if (s->options->use_llvm)
return false;
if (s->args->workgroup_ids[0].used)
@ -318,9 +313,9 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 8, 5);
} else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
if (s->gfx_level >= GFX12) {
if (s->options->gfx_level >= GFX12) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
} else if (s->gfx_level >= GFX10) {
} else if (s->options->gfx_level >= GFX10) {
replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
} else {
replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
@ -359,7 +354,7 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
break;
case nir_intrinsic_load_layer_id:
replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary,
16, s->gfx_level >= GFX12 ? 14 : 13);
16, s->options->gfx_level >= GFX12 ? 14 : 13);
break;
case nir_intrinsic_load_barycentric_optimize_amd: {
nir_def *prim_mask = ac_nir_load_arg(b, s->args, s->args->prim_mask);
@ -467,7 +462,7 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
replacement = s->tes_patch_id ? s->tes_patch_id :
ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
} else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
if (s->hw_stage == AC_HW_VERTEX_SHADER)
if (s->options->hw_stage == AC_HW_VERTEX_SHADER)
replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
else
replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
@ -490,38 +485,38 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
}
case nir_intrinsic_load_local_invocation_index:
/* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
if (s->gfx_level < GFX11 && b->shader->info.stage == MESA_SHADER_VERTEX &&
(s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER)) {
if (s->options->gfx_level < GFX11 && b->shader->info.stage == MESA_SHADER_VERTEX &&
(s->options->hw_stage == AC_HW_LOCAL_SHADER || s->options->hw_stage == AC_HW_HULL_SHADER)) {
if (!s->vs_rel_patch_id) {
s->vs_rel_patch_id = preload_arg(s, b->impl, s->args->vs_rel_patch_id,
s->args->tcs_rel_ids, 255);
}
replacement = s->vs_rel_patch_id;
} else if (s->workgroup_size <= s->wave_size) {
} else if (s->options->workgroup_size <= s->options->wave_size) {
/* Just a subgroup invocation ID. */
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
} else if (s->gfx_level < GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->wave_size == 64) {
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size), nir_imm_int(b, 0));
} else if (s->options->gfx_level < GFX12 && s->options->hw_stage == AC_HW_COMPUTE_SHADER && s->options->wave_size == 64) {
/* After the AND the bits are already multiplied by 64 (left shifted by 6) so we can just
* feed that to mbcnt. (GFX12 doesn't have tg_size)
*/
nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size), wave_id_mul_64);
} else {
nir_def *subgroup_id;
/* LLVM uses an intrinsic to get this on gfx12. */
if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->use_llvm) {
if (s->options->gfx_level >= GFX12 && s->options->hw_stage == AC_HW_COMPUTE_SHADER && s->options->use_llvm) {
subgroup_id = nir_load_subgroup_id(b);
} else {
subgroup_id = load_subgroup_id_lowered(s, b);
}
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
nir_imul_imm(b, subgroup_id, s->wave_size));
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size),
nir_imul_imm(b, subgroup_id, s->options->wave_size));
}
break;
case nir_intrinsic_load_subgroup_invocation:
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size), nir_imm_int(b, 0));
break;
case nir_intrinsic_load_task_ring_entry_amd:
replacement = ac_nir_load_arg(b, s->args, s->args->task_ring_entry);
@ -543,19 +538,12 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
}
bool
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
unsigned wave_size, unsigned workgroup_size, bool use_llvm,
const struct ac_shader_args *ac_args)
ac_nir_lower_intrinsics_to_args(nir_shader *shader, const struct ac_shader_args *ac_args,
const ac_nir_lower_intrinsics_to_args_options *options)
{
lower_intrinsics_to_args_state state = {
.gfx_level = gfx_level,
.use_llvm = use_llvm,
.hw_stage = hw_stage,
.has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
.wave_size = wave_size,
.workgroup_size = workgroup_size,
.args = ac_args,
.options = options,
};
return nir_shader_intrinsics_pass(shader, lower_intrinsic_to_arg,

View file

@ -489,10 +489,15 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_graphics_stat
.allow_fp16 = gfx_level >= GFX9,
});
NIR_PASS(_, stage->nir, ac_nir_lower_intrinsics_to_args, gfx_level,
pdev->info.has_ls_vgpr_init_bug && gfx_state && !gfx_state->vs.has_prolog,
radv_select_hw_stage(&stage->info, gfx_level), stage->info.wave_size, stage->info.workgroup_size,
radv_use_llvm_for_stage(pdev, stage->stage), &stage->args.ac);
NIR_PASS(_, stage->nir, ac_nir_lower_intrinsics_to_args, &stage->args.ac,
&(ac_nir_lower_intrinsics_to_args_options){
.gfx_level = gfx_level,
.has_ls_vgpr_init_bug = pdev->info.has_ls_vgpr_init_bug && gfx_state && !gfx_state->vs.has_prolog,
.hw_stage = radv_select_hw_stage(&stage->info, gfx_level),
.wave_size = stage->info.wave_size,
.workgroup_size = stage->info.workgroup_size,
.use_llvm = radv_use_llvm_for_stage(pdev, stage->stage),
});
NIR_PASS(_, stage->nir, radv_nir_lower_abi, gfx_level, stage, gfx_state, pdev->info.address32_hi);
if (!stage->key.optimisations_disabled) {

View file

@ -2377,8 +2377,15 @@ radv_create_gs_copy_shader(struct radv_device *device, struct vk_pipeline_cache
gs_copy_stage.info.user_sgprs_locs = gs_copy_stage.args.user_sgprs_locs;
gs_copy_stage.info.inline_push_constant_mask = gs_copy_stage.args.ac.inline_push_const_mask;
NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, pdev->info.gfx_level, pdev->info.has_ls_vgpr_init_bug,
AC_HW_VERTEX_SHADER, 64, 64, radv_use_llvm_for_stage(pdev, MESA_SHADER_VERTEX), &gs_copy_stage.args.ac);
NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, &gs_copy_stage.args.ac,
&(ac_nir_lower_intrinsics_to_args_options){
.gfx_level = pdev->info.gfx_level,
.has_ls_vgpr_init_bug = pdev->info.has_ls_vgpr_init_bug,
.hw_stage = AC_HW_VERTEX_SHADER,
.wave_size = 64,
.workgroup_size = 64,
.use_llvm = radv_use_llvm_for_stage(pdev, MESA_SHADER_VERTEX)
});
NIR_PASS(_, nir, radv_nir_lower_abi, pdev->info.gfx_level, &gs_copy_stage, gfx_state, pdev->info.address32_hi);
NIR_PASS(_, nir, ac_nir_lower_global_access);

View file

@ -1136,11 +1136,15 @@ static void si_postprocess_nir(struct si_nir_shader_ctx *ctx)
NIR_PASS(progress, nir, nir_lower_int64);
NIR_PASS(progress, nir, nir_lower_fp16_casts, nir_lower_fp16_split_fp64);
NIR_PASS(progress, nir, ac_nir_lower_intrinsics_to_args, sel->screen->info.gfx_level,
sel->screen->info.has_ls_vgpr_init_bug,
si_select_hw_stage(nir->info.stage, key, sel->screen->info.gfx_level),
shader->wave_size, si_get_max_workgroup_size(shader), !nir->info.use_aco_amd,
&ctx->args.ac);
NIR_PASS(progress, nir, ac_nir_lower_intrinsics_to_args, &ctx->args.ac,
&(ac_nir_lower_intrinsics_to_args_options){
.gfx_level = sel->screen->info.gfx_level,
.has_ls_vgpr_init_bug = sel->screen->info.has_ls_vgpr_init_bug,
.hw_stage = si_select_hw_stage(nir->info.stage, key, sel->screen->info.gfx_level),
.wave_size = shader->wave_size,
.workgroup_size = si_get_max_workgroup_size(shader),
.use_llvm = !nir->info.use_aco_amd,
});
/* LLVM keep non-uniform sampler as index, so can't do this in NIR.
* Must be done after si_nir_lower_resource().
@ -1366,9 +1370,15 @@ si_nir_generate_gs_copy_shader(struct si_screen *sscreen,
si_init_shader_args(shader, &linked.consumer.args, &gs_nir->info);
NIR_PASS(_, nir, si_nir_lower_abi, shader, &linked.consumer.args);
NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, sscreen->info.gfx_level,
sscreen->info.has_ls_vgpr_init_bug, AC_HW_VERTEX_SHADER, 64, 64,
!nir->info.use_aco_amd, &linked.consumer.args.ac);
NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, &linked.consumer.args.ac,
&(ac_nir_lower_intrinsics_to_args_options){
.gfx_level = sscreen->info.gfx_level,
.has_ls_vgpr_init_bug = sscreen->info.has_ls_vgpr_init_bug,
.hw_stage = AC_HW_VERTEX_SHADER,
.wave_size = 64,
.workgroup_size = 64,
.use_llvm = !nir->info.use_aco_amd,
});
NIR_PASS(_, nir, ac_nir_lower_global_access);
NIR_PASS(_, nir, nir_lower_int64);