ac/nir/tess: write TCS patch outputs to memory as vec4 stores at the end

This moves per-patch output VMEM stores to the end of the shader where they
execute only once. They are skipped if the whole workgroup discards
all patches.

If tcs_vertices_out == 1, per-patch output VMEM stores use the same lanes
as per-vertex output VMEM stores, which are aligned to 4 or 8 lanes to get
cached bandwidth for the stores.

Previously, per-patch outputs were stored to memory for every store_output
intrinsic in TCS.

Additionally, LDS is no longer allocated for per-patch outputs that are only
written and read by invocation 0, or they are written by all invocations
but not read, and don't have indirect indexing. This reduces LDS usage and
LDS traffic.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34780>
This commit is contained in:
Marek Olšák 2025-04-25 20:51:52 -04:00 committed by Marge Bot
parent c732306c5a
commit fa5e07d5f7
2 changed files with 162 additions and 28 deletions

View file

@ -97,6 +97,7 @@ typedef struct {
/* Generic per-patch slots. */
uint32_t vram_patch_output_mask;
uint32_t lds_patch_output_mask;
uint32_t vgpr_patch_output_mask; /* Hold the output values in VGPRs until the end. */
/* The highest index returned by map_io + 1. */
uint8_t highest_remapped_vram_output;

View file

@ -159,6 +159,10 @@ typedef struct {
nir_variable *tcs_tess_level[2]; /* outer, inner */
/* We can't use uint8_t due to a buggy gcc warning. */
uint16_t tcs_tess_level_chan_mask[2]; /* outer, inner */
/* Same, but for per-patch outputs. */
nir_variable *tcs_per_patch_outputs[MAX_VARYING][8];
uint8_t tcs_per_patch_output_vmem_chan_mask[MAX_VARYING];
} lower_tess_io_state;
typedef struct {
@ -204,17 +208,34 @@ ac_nir_get_tess_io_info(const nir_shader *tcs, const nir_tcs_info *tcs_info, uin
(tess_levels_only_written_by_invoc0 & tess_levels_only_read_by_invoc0) |
tess_levels_defined_by_all_invoc);
uint32_t patch_outputs_dont_need_lds =
tcs->info.patch_outputs_written & ~tcs->info.patch_outputs_read_indirectly &
~tcs->info.patch_outputs_written_indirectly &
((tcs_info->patch_outputs_only_written_by_invoc0 & ~tcs->info.patch_outputs_read) |
(tcs_info->patch_outputs_only_written_by_invoc0 & tcs_info->patch_outputs_only_read_by_invoc0) |
tcs_info->patch_outputs_defined_by_all_invoc);
/* Determine which outputs use LDS. */
io_info->lds_output_mask = (((tcs->info.outputs_read & tcs->info.outputs_written) |
tcs->info.tess.tcs_cross_invocation_outputs_written |
tcs->info.outputs_written_indirectly) & ~TESS_LVL_MASK) |
(tess_levels_written & ~tess_levels_dont_need_lds);
io_info->lds_patch_output_mask = tcs->info.patch_outputs_read & tcs->info.patch_outputs_written;
io_info->lds_patch_output_mask = tcs->info.patch_outputs_written & ~patch_outputs_dont_need_lds;
/* Determine which outputs hold their values in VGPRs. */
io_info->vgpr_output_mask = (tcs->info.outputs_written &
~(tcs->info.tess.tcs_cross_invocation_outputs_written |
tcs->info.outputs_written_indirectly) & ~TESS_LVL_MASK) |
(tess_levels_written &
(tess_levels_defined_by_all_invoc | tess_levels_only_written_by_invoc0));
io_info->vgpr_patch_output_mask = tcs->info.patch_outputs_written &
~tcs->info.patch_outputs_written_indirectly &
(tcs_info->patch_outputs_defined_by_all_invoc |
tcs_info->patch_outputs_only_written_by_invoc0);
/* Each output must have at least 1 bit in vgpr_output_mask or lds_output_mask or both. */
assert(tcs->info.outputs_written == (io_info->vgpr_output_mask | io_info->lds_output_mask));
assert(tcs->info.patch_outputs_written == (io_info->vgpr_patch_output_mask | io_info->lds_patch_output_mask));
io_info->highest_remapped_vram_output = 0;
io_info->highest_remapped_vram_patch_output = 0;
@ -637,26 +658,19 @@ lower_hs_output_store(nir_builder *b,
assert(store_val->bit_size & (16 | 32));
if (write_to_vmem && per_vertex) {
for (unsigned slot = 0; slot < semantics.num_slots; slot++) {
st->tcs_per_vertex_output_vmem_chan_mask[semantics.location + slot] |= write_mask << component;
if (write_to_vmem) {
if (per_vertex) {
for (unsigned slot = 0; slot < semantics.num_slots; slot++)
st->tcs_per_vertex_output_vmem_chan_mask[semantics.location + slot] |= write_mask << component;
} else {
assert(semantics.location >= VARYING_SLOT_PATCH0 && semantics.location <= VARYING_SLOT_PATCH31);
unsigned index = semantics.location - VARYING_SLOT_PATCH0;
for (unsigned slot = 0; slot < semantics.num_slots; slot++)
st->tcs_per_patch_output_vmem_chan_mask[index + slot] |= write_mask << component;
}
}
/* Only store per-patch outputs to memory here. (TODO: do it at the end of the shader) */
if (write_to_vmem && !per_vertex) {
nir_def *vmem_off = hs_per_patch_output_vmem_offset(b, st, semantics.location, component,
nir_get_io_offset_src(intrin)->ssa, 0, NULL);
nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
nir_def *zero = nir_imm_int(b, 0);
AC_NIR_STORE_IO(b, store_val, 0, write_mask, semantics.high_16bits,
nir_store_buffer_amd, hs_ring_tess_offchip, vmem_off, offchip_offset, zero,
.write_mask = store_write_mask, .base = store_const_offset,
.memory_modes = nir_var_shader_out, .access = ACCESS_COHERENT);
}
if (write_to_lds) {
nir_def *vertex_index = per_vertex ? nir_get_io_arrayed_index_src(intrin)->ssa : NULL;
nir_def *lds_off = hs_output_lds_offset(b, st, semantics.location, component,
@ -675,6 +689,17 @@ lower_hs_output_store(nir_builder *b,
st->tcs_per_vertex_outputs[semantics.location]);
}
if (write_to_vmem && !per_vertex) {
assert(semantics.location >= VARYING_SLOT_PATCH0 && semantics.location <= VARYING_SLOT_PATCH31);
unsigned index = semantics.location - VARYING_SLOT_PATCH0;
if (st->io_info.vgpr_patch_output_mask & BITFIELD_BIT(index)) {
assert(semantics.num_slots == 1);
store_output_variable(b, store_val, write_mask, component, semantics.high_16bits,
st->tcs_per_patch_outputs[index]);
}
}
/* Save tess levels that don't need to be stored in LDS into local variables. */
if (semantics.location == VARYING_SLOT_TESS_LEVEL_INNER ||
semantics.location == VARYING_SLOT_TESS_LEVEL_OUTER) {
@ -695,11 +720,11 @@ lower_hs_output_load(nir_builder *b,
lower_tess_io_state *st)
{
const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
const unsigned component = nir_intrinsic_component(intrin);
if ((io_sem.location == VARYING_SLOT_TESS_LEVEL_INNER ||
io_sem.location == VARYING_SLOT_TESS_LEVEL_OUTER) &&
!tcs_output_needs_lds(intrin, b->shader, st)) {
const unsigned component = nir_intrinsic_component(intrin);
const unsigned num_components = intrin->def.num_components;
const unsigned bit_size = intrin->def.bit_size;
unsigned i = io_sem.location - VARYING_SLOT_TESS_LEVEL_OUTER;
@ -708,13 +733,33 @@ lower_hs_output_load(nir_builder *b,
return nir_extract_bits(b, &var, 1, component * bit_size, num_components, bit_size);
}
if (io_sem.location >= VARYING_SLOT_PATCH0 && io_sem.location <= VARYING_SLOT_PATCH31 &&
!tcs_output_needs_lds(intrin, b->shader, st)) {
/* Return the per-patch output from local variables. */
assert(io_sem.num_slots == 1);
unsigned index = io_sem.location - VARYING_SLOT_PATCH0;
nir_def *comp[4];
for (unsigned i = 0; i < intrin->def.num_components; i++) {
nir_variable **var = &st->tcs_per_patch_outputs[index][component + io_sem.high_16bits * 4];
/* If the first use of the variable is a load, which means the variable hasn't been created yet,
* it's not always undef because we can be inside a loop that initializes the variable later
* in the loop but in an earlier iteration.
*/
comp[i] = nir_load_var(b, get_or_create_output_variable(b, var, intrin->def.bit_size));
}
return nir_vec(b, comp, intrin->def.num_components);
}
/* If an output is not stored by the shader, replace the output load by undef. */
if (!tcs_output_needs_lds(intrin, b->shader, st))
return nir_undef(b, intrin->def.num_components, intrin->def.bit_size);
nir_def *vertex_index = intrin->intrinsic == nir_intrinsic_load_per_vertex_output ?
nir_get_io_arrayed_index_src(intrin)->ssa : NULL;
nir_def *off = hs_output_lds_offset(b, st, io_sem.location, nir_intrinsic_component(intrin),
nir_def *off = hs_output_lds_offset(b, st, io_sem.location, component,
vertex_index, nir_get_io_offset_src(intrin)->ssa);
nir_def *load = NULL;
@ -1193,7 +1238,8 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
nir_builder *b = &builder; /* This is to avoid the & */
/* Insert a barrier to wait for output stores to LDS. */
if (shader->info.outputs_written & ~st->io_info.vgpr_output_mask) {
if (shader->info.outputs_written & ~st->io_info.vgpr_output_mask ||
shader->info.patch_outputs_written & ~st->io_info.vgpr_patch_output_mask) {
mesa_scope scope = st->tcs_out_patch_fits_subgroup ? SCOPE_SUBGROUP : SCOPE_WORKGROUP;
nir_barrier(b, .execution_scope = scope, .memory_scope = scope,
.memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
@ -1248,12 +1294,14 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
/* Gather per-vertex output values from local variables and LDS. */
nir_def *outputs[VARYING_SLOT_MAX] = {0};
nir_def *patch_outputs[MAX_VARYING] = {0};
nir_def *invocation_id = nir_load_invocation_id(b);
nir_def *zero = nir_imm_int(b, 0);
/* Don't load per-vertex outputs from LDS if all tess factors are 0. */
/* Don't load per-vertex and per-patch outputs from LDS if all tess factors are 0. */
nir_if *if_not_discarded = nir_push_if(b, nir_ine_imm(b, vote_result, VOTE_RESULT_ALL_TF_ZERO));
{
/* Load per-vertex outputs from LDS or local variables. */
u_foreach_bit64(slot, st->io_info.vram_output_mask & ~TESS_LVL_MASK) {
if (!st->tcs_per_vertex_output_vmem_chan_mask[slot])
continue;
@ -1274,12 +1322,39 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
outputs[slot] = make_vec4(b, comp);
}
/* Load per-patch outputs from LDS or local variables. */
u_foreach_bit(slot, st->io_info.vram_patch_output_mask) {
if (!st->tcs_per_patch_output_vmem_chan_mask[slot])
continue;
nir_def *comp[4] = {0};
/* Gather stored components either from LDS or from local variables. */
if ((shader->info.patch_outputs_written & ~st->io_info.vgpr_patch_output_mask) & BITFIELD_BIT(slot)) {
u_foreach_bit(i, st->tcs_per_patch_output_vmem_chan_mask[slot]) {
nir_def *lds_off = hs_output_lds_offset(b, st, VARYING_SLOT_PATCH0 + slot, i,
NULL, zero);
comp[i] = nir_load_shared(b, 1, 32, lds_off);
}
} else {
u_foreach_bit(i, st->tcs_per_patch_output_vmem_chan_mask[slot]) {
comp[i] = load_output_channel_from_var(b, st->tcs_per_patch_outputs[slot], i);
}
}
patch_outputs[slot] = make_vec4(b, comp);
}
}
nir_pop_if(b, if_not_discarded);
u_foreach_bit64(slot, st->io_info.vram_output_mask & ~TESS_LVL_MASK) {
if (outputs[slot])
outputs[slot] = nir_if_phi(b, outputs[slot], nir_undef(b, 4, 32));
}
u_foreach_bit(slot, st->io_info.vram_patch_output_mask) {
if (patch_outputs[slot])
patch_outputs[slot] = nir_if_phi(b, patch_outputs[slot], nir_undef(b, 4, 32));
}
if (st->gfx_level >= GFX9) {
/* Wrap the whole shader in a conditional block, allowing only TCS (HS) invocations to execute
@ -1302,32 +1377,45 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
if (outputs[slot])
outputs[slot] = nir_if_phi(b, outputs[slot], nir_undef(b, 4, 32));
}
u_foreach_bit(slot, st->io_info.vram_patch_output_mask) {
if (patch_outputs[slot])
patch_outputs[slot] = nir_if_phi(b, patch_outputs[slot], nir_undef(b, 4, 32));
}
}
/* Store per-vertex outputs to memory. */
nir_def *is_tcs_thread = nir_imm_true(b);
nir_def *is_pervertex_store_thread = nir_imm_true(b);
/* Align the EXEC mask to 8 lanes to overwrite whole 128B blocks on GFX10+, or 4 lanes to
* overwrite whole 64B blocks on GFX9.
*
* Per-patch outputs get the same treatment if tcs_vertices_out == 1, using the same
* aligned EXEC.
*
* GFX6-8 can't align the EXEC mask because it's not ~0.
*/
if (st->gfx_level >= GFX9) {
unsigned align = st->gfx_level >= GFX10 ? 8 : 4;
nir_def *num_tcs_threads = nir_ubfe_imm(b, nir_load_merged_wave_info_amd(b), 8, 8);
nir_def *aligned_tcs_threads = nir_align_imm(b, num_tcs_threads, align);
is_tcs_thread = nir_is_subgroup_invocation_lt_amd(b, num_tcs_threads);
is_pervertex_store_thread = nir_is_subgroup_invocation_lt_amd(b, aligned_tcs_threads);
}
nir_def *local_invocation_index = nir_load_local_invocation_index(b);
nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
bool patch_outputs_use_vertex_threads = shader->info.tess.tcs_vertices_out == 1;
nir_if *if_perpatch_stores = NULL;
zero = nir_imm_int(b, 0);
nir_if *if_pervertex_stores =
nir_push_if(b, nir_iand(b, is_pervertex_store_thread,
nir_ine_imm(b, vote_result, VOTE_RESULT_ALL_TF_ZERO)));
{
nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
nir_def *local_invocation_index = nir_load_local_invocation_index(b);
nir_def *zero = nir_imm_int(b, 0);
u_foreach_bit64(slot, st->io_info.vram_output_mask & ~TESS_LVL_MASK) {
if (!outputs[slot])
continue;
@ -1342,7 +1430,52 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
.memory_modes = nir_var_shader_out, .access = ACCESS_COHERENT);
}
}
nir_pop_if(b, if_pervertex_stores);
/* If we don't use vertex threads to store per-patch outputs, i.e. tcs_vertices_out != 1,
* store per-patch outputs in the first invocation of each patch.
*/
if (!patch_outputs_use_vertex_threads) {
nir_pop_if(b, if_pervertex_stores);
if_perpatch_stores =
nir_push_if(b, nir_iand(b, is_tcs_thread,
nir_iand(b, nir_ieq_imm(b, nir_load_invocation_id(b), 0),
nir_ine_imm(b, vote_result, VOTE_RESULT_ALL_TF_ZERO))));
}
{
u_foreach_bit(slot, st->io_info.vram_patch_output_mask) {
if (!patch_outputs[slot])
continue;
nir_def *vmem_off = hs_per_patch_output_vmem_offset(b, st, VARYING_SLOT_PATCH0 + slot, 0, zero, 0,
patch_outputs_use_vertex_threads ?
nir_imul_imm(b, local_invocation_index, 16u) :
NULL);
/* Always store whole vec4s to get cached bandwidth. Non-vec4 stores cause implicit memory loads
* to fill the rest of cache lines with this layout, as well as when a wave doesn't write whole
* 64B (GFX6-9) or 128B (GFX10+) blocks.
*
* A wave gets cached bandwidth for per-patch output stores only in these cases:
* - tcs_vertices_out == 1 and lanes are aligned to 4 (GFX6-9) or 8 (GFX10+) lanes (always done)
* - tcs_vertices_out == 2 or 4 except the last 4 (GFX6-9) or 8 (GFX10+) invocation_id==0 lanes
* if not all lanes are enabled in the last group of 4 or 8 in the last wave
* - tcs_vertices_out == 8 only with wave64 on GFX10+ except the last 8 invocation_id==0 lanes
* if not all lanes are enabled in the last group of 8 in the last wave
* - all full groups of 4 (GFX6-9) or 8 (GFX10+) lanes in the first wave because lane 0 outputs
* of the first wave are always aligned to 256B
*
* Note that the sparsity of invocation_id==0 lanes doesn't matter as long as the whole wave
* covers one or more whole 64B (GFX6-9) or 128B (GFX10+) blocks.
*/
nir_store_buffer_amd(b, patch_outputs[slot], hs_ring_tess_offchip, vmem_off, offchip_offset, zero,
.memory_modes = nir_var_shader_out, .access = ACCESS_COHERENT);
}
}
if (patch_outputs_use_vertex_threads)
nir_pop_if(b, if_pervertex_stores);
else
nir_pop_if(b, if_perpatch_stores);
nir_progress(true, impl, nir_metadata_none);
}