From c678844ccb0dfdccf344edfbe032102419ee9cf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ol=C5=A1=C3=A1k?= Date: Wed, 23 Apr 2025 13:52:39 -0400 Subject: [PATCH] ac/nir/tess: move LDS and VMEM output masks into a new info structure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This will replace LDS and VMEM output size computations in drivers. Reviewed-by: Timur Kristóf Part-of: --- src/amd/common/nir/ac_nir.h | 14 +++ .../common/nir/ac_nir_lower_tess_io_to_mem.c | 94 +++++++++++-------- 2 files changed, 69 insertions(+), 39 deletions(-) diff --git a/src/amd/common/nir/ac_nir.h b/src/amd/common/nir/ac_nir.h index dc48a606be2..0912021023e 100644 --- a/src/amd/common/nir/ac_nir.h +++ b/src/amd/common/nir/ac_nir.h @@ -89,6 +89,20 @@ bool ac_nir_optimize_outputs(nir_shader *nir, bool sprite_tex_disallowed, int8_t slot_remap[NUM_TOTAL_VARYING_SLOTS], uint8_t param_export_index[NUM_TOTAL_VARYING_SLOTS]); +typedef struct { + /* Per-vertex slots and tess levels. */ + uint64_t vram_output_mask; + uint64_t lds_output_mask; + uint64_t vgpr_output_mask; /* Hold the output values in VGPRs until the end. */ + /* Generic per-patch slots. */ + uint32_t vram_patch_output_mask; + uint32_t lds_patch_output_mask; +} ac_nir_tess_io_info; + +void +ac_nir_get_tess_io_info(const nir_shader *tcs, const nir_tcs_info *tcs_info, uint64_t tes_inputs_read, + uint32_t tes_patch_inputs_read, ac_nir_tess_io_info *io_info); + bool ac_nir_lower_ls_outputs_to_mem(nir_shader *ls, ac_nir_map_io_driver_location map, diff --git a/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c b/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c index 9932c087818..a5f615a662d 100644 --- a/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c +++ b/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c @@ -112,6 +112,7 @@ typedef struct { /* Which hardware generation we're dealing with */ enum amd_gfx_level gfx_level; nir_tcs_info tcs_info; + ac_nir_tess_io_info io_info; /* I/O semantic -> real location used by lowering. */ ac_nir_map_io_driver_location map_io; @@ -170,22 +171,39 @@ typedef struct { #define TESS_LVL_MASK (VARYING_BIT_TESS_LEVEL_OUTER | VARYING_BIT_TESS_LEVEL_INNER) +void +ac_nir_get_tess_io_info(const nir_shader *tcs, const nir_tcs_info *tcs_info, uint64_t tes_inputs_read, + uint32_t tes_patch_inputs_read, ac_nir_tess_io_info *io_info) +{ + io_info->vram_output_mask = tcs->info.outputs_written & tes_inputs_read; + io_info->vram_patch_output_mask = tcs->info.patch_outputs_written & tes_patch_inputs_read; + io_info->lds_output_mask = (((tcs->info.outputs_read & tcs->info.outputs_written) | + tcs->info.tess.tcs_cross_invocation_outputs_written | + tcs->info.outputs_written_indirectly) & ~TESS_LVL_MASK) | + (tcs_info->all_invocations_define_tess_levels ? 
+ 0 : (tcs->info.outputs_written & TESS_LVL_MASK)); + io_info->lds_patch_output_mask = tcs->info.patch_outputs_read & tcs->info.patch_outputs_written; + io_info->vgpr_output_mask = (tcs->info.outputs_written & + ~(tcs->info.tess.tcs_cross_invocation_outputs_written | + tcs->info.outputs_written_indirectly) & ~TESS_LVL_MASK); +} + static uint64_t -tcs_vram_per_vtx_out_mask(nir_shader *shader, lower_tess_io_state *st) +tcs_vram_per_vtx_out_mask(lower_tess_io_state *st) { - return st->tes_inputs_read & ~TESS_LVL_MASK; + return st->io_info.vram_output_mask & ~TESS_LVL_MASK; } static uint32_t -tcs_vram_tf_out_mask(nir_shader *shader, lower_tess_io_state *st) +tcs_vram_tf_out_mask(lower_tess_io_state *st) { - return st->tes_inputs_read & TESS_LVL_MASK; + return st->io_info.vram_output_mask & TESS_LVL_MASK; } static uint32_t -tcs_vram_per_patch_out_mask(nir_shader *shader, lower_tess_io_state *st) +tcs_vram_per_patch_out_mask(lower_tess_io_state *st) { - return st->tes_patch_inputs_read; + return st->io_info.vram_patch_output_mask; } static bool @@ -202,33 +220,30 @@ tcs_output_needs_vmem(nir_intrinsic_instr *intrin, intrin->intrinsic == nir_intrinsic_load_per_vertex_output; if (per_vertex) { - return tcs_vram_per_vtx_out_mask(shader, st) & BITFIELD64_BIT(loc); + return tcs_vram_per_vtx_out_mask(st) & BITFIELD64_BIT(loc); } else if (loc == VARYING_SLOT_TESS_LEVEL_OUTER || loc == VARYING_SLOT_TESS_LEVEL_INNER) { return false; } else { - return tcs_vram_per_patch_out_mask(shader, st) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0); + return tcs_vram_per_patch_out_mask(st) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0); } } static uint64_t -tcs_lds_per_vtx_out_mask(nir_shader *shader) +tcs_lds_per_vtx_out_mask(lower_tess_io_state *st) { - return ((shader->info.outputs_read & shader->info.outputs_written) | - shader->info.tess.tcs_cross_invocation_outputs_written | - shader->info.outputs_written_indirectly) & ~TESS_LVL_MASK; + return st->io_info.lds_output_mask & ~TESS_LVL_MASK; } static uint64_t -tcs_lds_tf_out_mask(nir_shader *shader, lower_tess_io_state *st) +tcs_lds_tf_out_mask(lower_tess_io_state *st) { - return st->tcs_info.all_invocations_define_tess_levels ? 
- 0ull : (shader->info.outputs_written & TESS_LVL_MASK); + return st->io_info.lds_output_mask & TESS_LVL_MASK; } static uint32_t -tcs_lds_per_patch_out_mask(nir_shader *shader) +tcs_lds_per_patch_out_mask(lower_tess_io_state *st) { - return shader->info.patch_outputs_read & shader->info.patch_outputs_written; + return st->io_info.lds_patch_output_mask; } static bool @@ -241,11 +256,11 @@ tcs_output_needs_lds(nir_intrinsic_instr *intrin, intrin->intrinsic == nir_intrinsic_load_per_vertex_output; if (per_vertex) { - return tcs_lds_per_vtx_out_mask(shader) & BITFIELD64_BIT(loc); + return tcs_lds_per_vtx_out_mask(st) & BITFIELD64_BIT(loc); } else if (loc == VARYING_SLOT_TESS_LEVEL_OUTER || loc == VARYING_SLOT_TESS_LEVEL_INNER) { - return tcs_lds_tf_out_mask(shader, st) & BITFIELD64_BIT(loc); + return tcs_lds_tf_out_mask(st) & BITFIELD64_BIT(loc); } else { - return tcs_lds_per_patch_out_mask(shader) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0); + return tcs_lds_per_patch_out_mask(st) & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0); } } @@ -379,18 +394,18 @@ hs_output_lds_map_io_location(nir_shader *shader, lower_tess_io_state *st) { if (!per_vertex) { - const uint64_t tf_mask = tcs_lds_tf_out_mask(shader, st); + const uint64_t tf_mask = tcs_lds_tf_out_mask(st); if (loc == VARYING_SLOT_TESS_LEVEL_INNER || loc == VARYING_SLOT_TESS_LEVEL_OUTER) { assert(tf_mask & BITFIELD64_BIT(loc)); return util_bitcount64(tf_mask & BITFIELD64_MASK(loc)); } - const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(shader); + const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(st); assert(patch_out_mask & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0)); return util_bitcount64(tf_mask) + util_bitcount(patch_out_mask & BITFIELD_MASK(loc - VARYING_SLOT_PATCH0)); } else { - const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(shader); + const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(st); assert(per_vertex_mask & BITFIELD64_BIT(loc)); return util_bitcount64(per_vertex_mask & BITFIELD64_MASK(loc)); } @@ -400,9 +415,9 @@ static nir_def * hs_output_lds_offset(nir_builder *b, lower_tess_io_state *st, unsigned location, unsigned component, nir_def *vertex_index, nir_def *io_offset) { - const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(b->shader); - const uint64_t tf_mask = tcs_lds_tf_out_mask(b->shader, st); - const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(b->shader); + const uint64_t per_vertex_mask = tcs_lds_per_vtx_out_mask(st); + const uint64_t tf_mask = tcs_lds_tf_out_mask(st); + const uint32_t patch_out_mask = tcs_lds_per_patch_out_mask(st); unsigned tcs_num_reserved_outputs = util_bitcount64(per_vertex_mask); unsigned tcs_num_reserved_patch_outputs = util_bitcount64(tf_mask) + util_bitcount(patch_out_mask); @@ -463,18 +478,18 @@ hs_output_vram_map_io_location(nir_shader *shader, * Map varyings to a prefix sum of the IO mask to save space in VRAM. 
*/ if (!per_vertex) { - const uint64_t tf_mask = tcs_vram_tf_out_mask(shader, st); + const uint64_t tf_mask = tcs_vram_tf_out_mask(st); if (loc == VARYING_SLOT_TESS_LEVEL_INNER || loc == VARYING_SLOT_TESS_LEVEL_OUTER) { assert(tf_mask & BITFIELD64_BIT(loc)); return util_bitcount64(tf_mask & BITFIELD64_MASK(loc)); } - const uint32_t patch_out_mask = tcs_vram_per_patch_out_mask(shader, st); + const uint32_t patch_out_mask = tcs_vram_per_patch_out_mask(st); assert(patch_out_mask & BITFIELD_BIT(loc - VARYING_SLOT_PATCH0)); return util_bitcount64(tf_mask) + util_bitcount(patch_out_mask & BITFIELD_MASK(loc - VARYING_SLOT_PATCH0)); } else { - const uint64_t per_vertex_mask = tcs_vram_per_vtx_out_mask(shader, st); + const uint64_t per_vertex_mask = tcs_vram_per_vtx_out_mask(st); assert(per_vertex_mask & BITFIELD64_BIT(loc)); return util_bitcount64(per_vertex_mask & BITFIELD64_MASK(loc)); } @@ -623,8 +638,7 @@ lower_hs_output_store(nir_builder *b, /* Store per-vertex outputs to temp variables. The outputs will be stored to memory at the end of the shader. */ if (write_to_vmem && per_vertex && - !((b->shader->info.tess.tcs_cross_invocation_outputs_written | - b->shader->info.outputs_written_indirectly) & BITFIELD64_BIT(semantics.location))) { + st->io_info.vgpr_output_mask & BITFIELD64_BIT(semantics.location)) { assert(semantics.location < ARRAY_SIZE(st->tcs_per_vertex_outputs)); assert(semantics.num_slots == 1); @@ -1162,8 +1176,7 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st) /* Insert a barrier to wait for output stores to LDS. */ if (!st->tcs_info.all_invocations_define_tess_levels || - shader->info.tess.tcs_cross_invocation_outputs_written || - shader->info.outputs_written_indirectly) { + shader->info.outputs_written & ~st->io_info.vgpr_output_mask) { mesa_scope scope = st->tcs_out_patch_fits_subgroup ? SCOPE_SUBGROUP : SCOPE_WORKGROUP; nir_barrier(b, .execution_scope = scope, .memory_scope = scope, .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared); @@ -1224,15 +1237,14 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st) /* Don't load per-vertex outputs from LDS if all tess factors are 0. */ nir_if *if_not_discarded = nir_push_if(b, nir_ine_imm(b, vote_result, VOTE_RESULT_ALL_TF_ZERO)); { - u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) { + u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(st)) { if (!st->tcs_per_vertex_output_vmem_chan_mask[slot]) continue; nir_def *comp[4] = {0}; /* Gather stored components either from LDS or from local variables. 
*/ - if ((shader->info.tess.tcs_cross_invocation_outputs_written | - shader->info.outputs_written_indirectly) & BITFIELD64_BIT(slot)) { + if ((shader->info.outputs_written & ~st->io_info.vgpr_output_mask) & BITFIELD64_BIT(slot)) { u_foreach_bit(i, st->tcs_per_vertex_output_vmem_chan_mask[slot]) { nir_def *lds_off = hs_output_lds_offset(b, st, slot, i, invocation_id, zero); comp[i] = nir_load_shared(b, 1, 32, lds_off); @@ -1247,7 +1259,7 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st) } } nir_pop_if(b, if_not_discarded); - u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) { + u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(st)) { if (outputs[slot]) outputs[slot] = nir_if_phi(b, outputs[slot], nir_undef(b, 4, 32)); } @@ -1269,7 +1281,7 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st) nir_pop_if(b, if_tcs); vote_result = nir_if_phi(b, vote_result, nir_undef(b, 1, 32)); /* no-op, it should be an SGPR */ - u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) { + u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(st)) { if (outputs[slot]) outputs[slot] = nir_if_phi(b, outputs[slot], nir_undef(b, 4, 32)); } @@ -1299,7 +1311,7 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st) nir_def *local_invocation_index = nir_load_local_invocation_index(b); nir_def *zero = nir_imm_int(b, 0); - u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) { + u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(st)) { if (!outputs[slot]) continue; @@ -1453,6 +1465,8 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, const nir_tcs_info *info, .map_io = map, }; + ac_nir_get_tess_io_info(shader, info, tes_inputs_read, tes_patch_inputs_read, &state.io_info); + if (state.tcs_info.all_invocations_define_tess_levels) { nir_function_impl *impl = nir_shader_get_entrypoint(shader); state.tcs_tess_level_outer = @@ -1483,6 +1497,8 @@ ac_nir_lower_tes_inputs_to_mem(nir_shader *shader, assert(shader->info.stage == MESA_SHADER_TESS_EVAL); lower_tess_io_state state = { + .io_info.vram_output_mask = shader->info.inputs_read, + .io_info.vram_patch_output_mask = shader->info.patch_inputs_read, .map_io = map, .tes_inputs_read = shader->info.inputs_read, .tes_patch_inputs_read = shader->info.patch_inputs_read,
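
Usage note (illustrative, not part of the commit): the point of exposing
ac_nir_tess_io_info is that drivers which currently recompute these masks to
size their TCS output memory can consume the structure directly, so the size
computation can no longer diverge from the layout the lowering pass emits.
The sketch below shows one way that could look. It assumes the prefix-sum
slot layout used by hs_output_lds_offset()/hs_output_vram_map_io_location()
and a vec4 (16-byte) stride per output slot; the helper name and its
out-parameters are hypothetical, not anything this patch adds.

/* Illustrative only: derive per-patch LDS and VRAM output sizes from
 * ac_nir_tess_io_info. Assumes the prefix-sum slot layout used by the
 * lowering pass and 16 bytes (one vec4) per output slot. The helper
 * name is hypothetical. */
static void
example_compute_tess_output_sizes(const ac_nir_tess_io_info *io_info,
                                  unsigned tcs_vertices_out,
                                  unsigned *lds_patch_size,
                                  unsigned *vram_patch_size)
{
   const uint64_t tess_lvl_mask =
      VARYING_BIT_TESS_LEVEL_OUTER | VARYING_BIT_TESS_LEVEL_INNER;

   /* LDS: per-vertex slots are replicated for each TCS output vertex;
    * tess levels and generic per-patch slots are stored once per patch. */
   const unsigned lds_pervtx_slots =
      util_bitcount64(io_info->lds_output_mask & ~tess_lvl_mask);
   const unsigned lds_patch_slots =
      util_bitcount64(io_info->lds_output_mask & tess_lvl_mask) +
      util_bitcount(io_info->lds_patch_output_mask);
   *lds_patch_size =
      (lds_pervtx_slots * tcs_vertices_out + lds_patch_slots) * 16;

   /* VRAM: only outputs actually read by the TES are stored. */
   const unsigned vram_pervtx_slots =
      util_bitcount64(io_info->vram_output_mask & ~tess_lvl_mask);
   const unsigned vram_patch_slots =
      util_bitcount64(io_info->vram_output_mask & tess_lvl_mask) +
      util_bitcount(io_info->vram_patch_output_mask);
   *vram_patch_size =
      (vram_pervtx_slots * tcs_vertices_out + vram_patch_slots) * 16;
}

Outputs kept in vgpr_output_mask never touch LDS, which is why hs_finale can
skip the LDS barrier and the nir_load_shared round trip for them; the sizing
above deliberately ignores that mask, since it only affects where values live
within the shader, not the memory footprint.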