diff --git a/src/amd/common/nir/ac_nir.h b/src/amd/common/nir/ac_nir.h
index 739a5c159c4..6e06568d991 100644
--- a/src/amd/common/nir/ac_nir.h
+++ b/src/amd/common/nir/ac_nir.h
@@ -89,10 +89,17 @@ ac_nir_load_smem(nir_builder *b, unsigned num_components, nir_def *addr, nir_def
 
 bool ac_nir_lower_sin_cos(nir_shader *shader);
 
-bool ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
-                                     bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
-                                     unsigned wave_size, unsigned workgroup_size, bool use_llvm,
-                                     const struct ac_shader_args *ac_args);
+typedef struct {
+   enum amd_gfx_level gfx_level;
+   bool has_ls_vgpr_init_bug;
+   const enum ac_hw_stage hw_stage;
+   unsigned wave_size;
+   unsigned workgroup_size;
+   bool use_llvm;
+} ac_nir_lower_intrinsics_to_args_options;
+
+bool ac_nir_lower_intrinsics_to_args(nir_shader *shader, const struct ac_shader_args *ac_args,
+                                     const ac_nir_lower_intrinsics_to_args_options *options);
 
 nir_xfb_info *ac_nir_get_sorted_xfb_info(const nir_shader *nir);
 
diff --git a/src/amd/common/nir/ac_nir_lower_intrinsics_to_args.c b/src/amd/common/nir/ac_nir_lower_intrinsics_to_args.c
index 5296cc52473..567ab9c9594 100644
--- a/src/amd/common/nir/ac_nir_lower_intrinsics_to_args.c
+++ b/src/amd/common/nir/ac_nir_lower_intrinsics_to_args.c
@@ -12,12 +12,7 @@ typedef struct {
    const struct ac_shader_args *const args;
-   const enum amd_gfx_level gfx_level;
-   bool use_llvm;
-   bool has_ls_vgpr_init_bug;
-   unsigned wave_size;
-   unsigned workgroup_size;
-   const enum ac_hw_stage hw_stage;
+   const ac_nir_lower_intrinsics_to_args_options *options;
 
    nir_def *vertex_id;
    nir_def *instance_id;
@@ -36,8 +31,8 @@ preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct a
    nir_def *value = ac_nir_load_arg_upper_bound(&start_b, s->args, arg, upper_bound);
    /* If there are no HS threads, SPI mistakenly loads the LS VGPRs starting at VGPR 0.
     */
-   if ((s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER) &&
-       s->has_ls_vgpr_init_bug) {
+   if ((s->options->hw_stage == AC_HW_LOCAL_SHADER || s->options->hw_stage == AC_HW_HULL_SHADER) &&
+       s->options->has_ls_vgpr_init_bug) {
       nir_def *count = ac_nir_unpack_arg(&start_b, s->args, s->args->merged_wave_info, 8, 8);
       nir_def *hs_empty = nir_ieq_imm(&start_b, count, 0);
       value = nir_bcsel(&start_b, hs_empty,
@@ -50,14 +45,14 @@ preload_arg(lower_intrinsics_to_args_state *s, nir_function_impl *impl, struct a
 static nir_def *
 load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
 {
-   if (s->workgroup_size <= s->wave_size) {
+   if (s->options->workgroup_size <= s->options->wave_size) {
       return nir_imm_int(b, 0);
-   } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
-      if (s->gfx_level >= GFX12) {
-         assert(!s->use_llvm);
+   } else if (s->options->hw_stage == AC_HW_COMPUTE_SHADER) {
+      if (s->options->gfx_level >= GFX12) {
+         assert(!s->options->use_llvm);
          nir_def *ttmp8 = nir_load_ttmp_register_amd(b, .base = 8);
          return nir_ubfe_imm(b, ttmp8, 25, 5);
-      } else if (s->gfx_level >= GFX10_3) {
+      } else if (s->options->gfx_level >= GFX10_3) {
          assert(s->args->tg_size.used);
         return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 20, 5);
       } else {
@@ -68,8 +63,8 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
          assert(s->args->tg_size.used);
          return ac_nir_unpack_arg(b, s->args, s->args->tg_size, 6, 6);
       }
-   } else if (s->hw_stage == AC_HW_HULL_SHADER) {
-      if (s->gfx_level >= GFX11) {
+   } else if (s->options->hw_stage == AC_HW_HULL_SHADER) {
+      if (s->options->gfx_level >= GFX11) {
         assert(s->args->tcs_wave_id.used);
         return ac_nir_unpack_arg(b, s->args, s->args->tcs_wave_id, 0, 3);
       } else if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
@@ -94,12 +89,12 @@ load_subgroup_id_lowered(lower_intrinsics_to_args_state *s, nir_builder *b)
             sgpr_local_invocation_index = nir_iadd(b, sgpr_patch_start, sgpr_invocation_id);
          }
       }
-      return nir_ushr_imm(b, sgpr_local_invocation_index, util_logbase2(s->wave_size));
+      return nir_ushr_imm(b, sgpr_local_invocation_index, util_logbase2(s->options->wave_size));
       } else {
          UNREACHABLE("unimplemented for LS");
       }
-   } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
-              s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
+   } else if (s->options->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
+              s->options->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
       assert(s->args->merged_wave_info.used);
       return ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 24, 4);
    } else {
@@ -117,17 +112,17 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
    switch (intrin->intrinsic) {
    case nir_intrinsic_load_subgroup_id:
       /* LLVM uses an intrinsic to get this on gfx12.
        */
-      if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->use_llvm)
+      if (s->options->gfx_level >= GFX12 && s->options->hw_stage == AC_HW_COMPUTE_SHADER && s->options->use_llvm)
          return false;
       replacement = load_subgroup_id_lowered(s, b);
       break;
    case nir_intrinsic_load_num_subgroups: {
-      if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
+      if (s->options->hw_stage == AC_HW_COMPUTE_SHADER) {
          assert(s->args->tg_size.used);
          replacement = ac_nir_unpack_arg(b, s->args, s->args->tg_size, 0, 6);
-      } else if (s->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
-                 s->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
+      } else if (s->options->hw_stage == AC_HW_LEGACY_GEOMETRY_SHADER ||
+                 s->options->hw_stage == AC_HW_NEXT_GEN_GEOMETRY_SHADER) {
          assert(s->args->merged_wave_info.used);
          replacement = ac_nir_unpack_arg(b, s->args, s->args->merged_wave_info, 28, 4);
       } else {
@@ -141,19 +136,19 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
          /* This lowering is only valid with fast_launch = 2, otherwise we assume that
          * lower_workgroup_id_to_index removed any uses of the workgroup id by this point.
          */
-         assert(s->gfx_level >= GFX11);
+         assert(s->options->gfx_level >= GFX11);
          nir_def *xy = ac_nir_load_arg(b, s->args, s->args->tess_offchip_offset);
          nir_def *z = ac_nir_load_arg(b, s->args, s->args->gs_attr_offset);
          replacement = nir_vec3(b, nir_extract_u16(b, xy, nir_imm_int(b, 0)),
                                 nir_extract_u16(b, xy, nir_imm_int(b, 1)),
                                 nir_extract_u16(b, z, nir_imm_int(b, 1)));
-      } else if (s->hw_stage == AC_HW_COMPUTE_SHADER) {
+      } else if (s->options->hw_stage == AC_HW_COMPUTE_SHADER) {
          nir_def *undef = nir_undef(b, 1, 32);
          nir_def *ids[3] = {undef, undef, undef};
 
-         if (s->gfx_level >= GFX12) {
+         if (s->options->gfx_level >= GFX12) {
             /* LLVM uses intrinsics to get workgroup IDs on gfx12. */
-            if (s->use_llvm)
+            if (s->options->use_llvm)
                return false;
 
             if (s->args->workgroup_ids[0].used)
@@ -318,9 +313,9 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
       if (b->shader->info.stage == MESA_SHADER_TESS_CTRL) {
          replacement = ac_nir_unpack_arg(b, s->args, s->args->tcs_rel_ids, 8, 5);
       } else if (b->shader->info.stage == MESA_SHADER_GEOMETRY) {
-         if (s->gfx_level >= GFX12) {
+         if (s->options->gfx_level >= GFX12) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_vtx_offset[0], 27, 5);
-         } else if (s->gfx_level >= GFX10) {
+         } else if (s->options->gfx_level >= GFX10) {
            replacement = ac_nir_unpack_arg(b, s->args, s->args->gs_invocation_id, 0, 5);
          } else {
            replacement = ac_nir_load_arg_upper_bound(b, s->args, s->args->gs_invocation_id, 31);
@@ -359,7 +354,7 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
       break;
    case nir_intrinsic_load_layer_id:
       replacement = ac_nir_unpack_arg(b, s->args, s->args->ancillary,
-                                      16, s->gfx_level >= GFX12 ? 14 : 13);
+                                      16, s->options->gfx_level >= GFX12 ? 14 : 13);
       break;
    case nir_intrinsic_load_barycentric_optimize_amd: {
       nir_def *prim_mask = ac_nir_load_arg(b, s->args, s->args->prim_mask);
@@ -467,7 +462,7 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
         replacement = s->tes_patch_id ?
            s->tes_patch_id : ac_nir_load_arg(b, s->args, s->args->tes_patch_id);
      } else if (b->shader->info.stage == MESA_SHADER_VERTEX) {
-         if (s->hw_stage == AC_HW_VERTEX_SHADER)
+         if (s->options->hw_stage == AC_HW_VERTEX_SHADER)
            replacement = ac_nir_load_arg(b, s->args, s->args->vs_prim_id); /* legacy */
         else
            replacement = ac_nir_load_arg(b, s->args, s->args->gs_prim_id); /* NGG */
@@ -490,38 +485,38 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
    }
    case nir_intrinsic_load_local_invocation_index:
       /* GFX11 HS has subgroup_id, so use it instead of vs_rel_patch_id. */
-      if (s->gfx_level < GFX11 && b->shader->info.stage == MESA_SHADER_VERTEX &&
-          (s->hw_stage == AC_HW_LOCAL_SHADER || s->hw_stage == AC_HW_HULL_SHADER)) {
+      if (s->options->gfx_level < GFX11 && b->shader->info.stage == MESA_SHADER_VERTEX &&
+          (s->options->hw_stage == AC_HW_LOCAL_SHADER || s->options->hw_stage == AC_HW_HULL_SHADER)) {
          if (!s->vs_rel_patch_id) {
            s->vs_rel_patch_id =
               preload_arg(s, b->impl, s->args->vs_rel_patch_id, s->args->tcs_rel_ids, 255);
          }
          replacement = s->vs_rel_patch_id;
-      } else if (s->workgroup_size <= s->wave_size) {
+      } else if (s->options->workgroup_size <= s->options->wave_size) {
          /* Just a subgroup invocation ID. */
-         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
-      } else if (s->gfx_level < GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->wave_size == 64) {
+         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size), nir_imm_int(b, 0));
+      } else if (s->options->gfx_level < GFX12 && s->options->hw_stage == AC_HW_COMPUTE_SHADER && s->options->wave_size == 64) {
          /* After the AND the bits are already multiplied by 64 (left shifted by 6) so we can just
          * feed that to mbcnt. (GFX12 doesn't have tg_size)
          */
          nir_def *wave_id_mul_64 = nir_iand_imm(b, ac_nir_load_arg(b, s->args, s->args->tg_size), 0xfc0);
-         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), wave_id_mul_64);
+         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size), wave_id_mul_64);
       } else {
          nir_def *subgroup_id;
 
          /* LLVM uses an intrinsic to get this on gfx12.
           */
-         if (s->gfx_level >= GFX12 && s->hw_stage == AC_HW_COMPUTE_SHADER && s->use_llvm) {
+         if (s->options->gfx_level >= GFX12 && s->options->hw_stage == AC_HW_COMPUTE_SHADER && s->options->use_llvm) {
            subgroup_id = nir_load_subgroup_id(b);
          } else {
            subgroup_id = load_subgroup_id_lowered(s, b);
          }
-         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size),
-                                     nir_imul_imm(b, subgroup_id, s->wave_size));
+         replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size),
+                                     nir_imul_imm(b, subgroup_id, s->options->wave_size));
       }
       break;
    case nir_intrinsic_load_subgroup_invocation:
-      replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->wave_size), nir_imm_int(b, 0));
+      replacement = nir_mbcnt_amd(b, nir_imm_intN_t(b, ~0ull, s->options->wave_size), nir_imm_int(b, 0));
       break;
    case nir_intrinsic_load_task_ring_entry_amd:
       replacement = ac_nir_load_arg(b, s->args, s->args->task_ring_entry);
@@ -543,19 +538,12 @@ lower_intrinsic_to_arg(nir_builder *b, nir_intrinsic_instr *intrin, void *state)
 }
 
 bool
-ac_nir_lower_intrinsics_to_args(nir_shader *shader, const enum amd_gfx_level gfx_level,
-                                bool has_ls_vgpr_init_bug, const enum ac_hw_stage hw_stage,
-                                unsigned wave_size, unsigned workgroup_size, bool use_llvm,
-                                const struct ac_shader_args *ac_args)
+ac_nir_lower_intrinsics_to_args(nir_shader *shader, const struct ac_shader_args *ac_args,
+                                const ac_nir_lower_intrinsics_to_args_options *options)
 {
    lower_intrinsics_to_args_state state = {
-      .gfx_level = gfx_level,
-      .use_llvm = use_llvm,
-      .hw_stage = hw_stage,
-      .has_ls_vgpr_init_bug = has_ls_vgpr_init_bug,
-      .wave_size = wave_size,
-      .workgroup_size = workgroup_size,
       .args = ac_args,
+      .options = options,
    };
 
    return nir_shader_intrinsics_pass(shader, lower_intrinsic_to_arg,
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index 4a6ea9988a4..d753ca30ed8 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -489,10 +489,15 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_graphics_stat
                .allow_fp16 = gfx_level >= GFX9,
             });
 
-   NIR_PASS(_, stage->nir, ac_nir_lower_intrinsics_to_args, gfx_level,
-            pdev->info.has_ls_vgpr_init_bug && gfx_state && !gfx_state->vs.has_prolog,
-            radv_select_hw_stage(&stage->info, gfx_level), stage->info.wave_size, stage->info.workgroup_size,
-            radv_use_llvm_for_stage(pdev, stage->stage), &stage->args.ac);
+   NIR_PASS(_, stage->nir, ac_nir_lower_intrinsics_to_args, &stage->args.ac,
+            &(ac_nir_lower_intrinsics_to_args_options){
+               .gfx_level = gfx_level,
+               .has_ls_vgpr_init_bug = pdev->info.has_ls_vgpr_init_bug && gfx_state && !gfx_state->vs.has_prolog,
+               .hw_stage = radv_select_hw_stage(&stage->info, gfx_level),
+               .wave_size = stage->info.wave_size,
+               .workgroup_size = stage->info.workgroup_size,
+               .use_llvm = radv_use_llvm_for_stage(pdev, stage->stage),
+            });
 
    NIR_PASS(_, stage->nir, radv_nir_lower_abi, gfx_level, stage, gfx_state, pdev->info.address32_hi);
 
    if (!stage->key.optimisations_disabled) {
diff --git a/src/amd/vulkan/radv_pipeline_graphics.c b/src/amd/vulkan/radv_pipeline_graphics.c
index 4209afe2f0e..c4a21a2b7ab 100644
--- a/src/amd/vulkan/radv_pipeline_graphics.c
+++ b/src/amd/vulkan/radv_pipeline_graphics.c
@@ -2377,8 +2377,15 @@ radv_create_gs_copy_shader(struct radv_device *device, struct vk_pipeline_cache
    gs_copy_stage.info.user_sgprs_locs = gs_copy_stage.args.user_sgprs_locs;
    gs_copy_stage.info.inline_push_constant_mask = gs_copy_stage.args.ac.inline_push_const_mask;
 
-   NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, pdev->info.gfx_level, pdev->info.has_ls_vgpr_init_bug,
-            AC_HW_VERTEX_SHADER, 64, 64, radv_use_llvm_for_stage(pdev, MESA_SHADER_VERTEX), &gs_copy_stage.args.ac);
+   NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, &gs_copy_stage.args.ac,
+            &(ac_nir_lower_intrinsics_to_args_options){
+               .gfx_level = pdev->info.gfx_level,
+               .has_ls_vgpr_init_bug = pdev->info.has_ls_vgpr_init_bug,
+               .hw_stage = AC_HW_VERTEX_SHADER,
+               .wave_size = 64,
+               .workgroup_size = 64,
+               .use_llvm = radv_use_llvm_for_stage(pdev, MESA_SHADER_VERTEX)
+            });
 
    NIR_PASS(_, nir, radv_nir_lower_abi, pdev->info.gfx_level, &gs_copy_stage, gfx_state, pdev->info.address32_hi);
    NIR_PASS(_, nir, ac_nir_lower_global_access);
diff --git a/src/gallium/drivers/radeonsi/si_shader.c b/src/gallium/drivers/radeonsi/si_shader.c
index 0acd37d6e8b..af73c5b8ae6 100644
--- a/src/gallium/drivers/radeonsi/si_shader.c
+++ b/src/gallium/drivers/radeonsi/si_shader.c
@@ -1136,11 +1136,15 @@ static void si_postprocess_nir(struct si_nir_shader_ctx *ctx)
    NIR_PASS(progress, nir, nir_lower_int64);
    NIR_PASS(progress, nir, nir_lower_fp16_casts, nir_lower_fp16_split_fp64);
 
-   NIR_PASS(progress, nir, ac_nir_lower_intrinsics_to_args, sel->screen->info.gfx_level,
-            sel->screen->info.has_ls_vgpr_init_bug,
-            si_select_hw_stage(nir->info.stage, key, sel->screen->info.gfx_level),
-            shader->wave_size, si_get_max_workgroup_size(shader), !nir->info.use_aco_amd,
-            &ctx->args.ac);
+   NIR_PASS(progress, nir, ac_nir_lower_intrinsics_to_args, &ctx->args.ac,
+            &(ac_nir_lower_intrinsics_to_args_options){
+               .gfx_level = sel->screen->info.gfx_level,
+               .has_ls_vgpr_init_bug = sel->screen->info.has_ls_vgpr_init_bug,
+               .hw_stage = si_select_hw_stage(nir->info.stage, key, sel->screen->info.gfx_level),
+               .wave_size = shader->wave_size,
+               .workgroup_size = si_get_max_workgroup_size(shader),
+               .use_llvm = !nir->info.use_aco_amd,
+            });
 
    /* LLVM keep non-uniform sampler as index, so can't do this in NIR.
    * Must be done after si_nir_lower_resource().
@@ -1366,9 +1370,15 @@ si_nir_generate_gs_copy_shader(struct si_screen *sscreen,
    si_init_shader_args(shader, &linked.consumer.args, &gs_nir->info);
 
    NIR_PASS(_, nir, si_nir_lower_abi, shader, &linked.consumer.args);
-   NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, sscreen->info.gfx_level,
-            sscreen->info.has_ls_vgpr_init_bug, AC_HW_VERTEX_SHADER, 64, 64,
-            !nir->info.use_aco_amd, &linked.consumer.args.ac);
+   NIR_PASS(_, nir, ac_nir_lower_intrinsics_to_args, &linked.consumer.args.ac,
+            &(ac_nir_lower_intrinsics_to_args_options){
+               .gfx_level = sscreen->info.gfx_level,
+               .has_ls_vgpr_init_bug = sscreen->info.has_ls_vgpr_init_bug,
+               .hw_stage = AC_HW_VERTEX_SHADER,
+               .wave_size = 64,
+               .workgroup_size = 64,
+               .use_llvm = !nir->info.use_aco_amd,
+            });
 
    NIR_PASS(_, nir, ac_nir_lower_global_access);
    NIR_PASS(_, nir, nir_lower_int64);