ac/nir/tess: write TCS per-vertex outputs to memory as vec4 stores at the end

This improves write throughput for TCS outputs. It follows the same idea
as attribute stores in hw GS. The improvement is easily measurable with
a microbenchmark.

It also has the advantage that multiple output stores to the same address
don't result in multiple memory stores. Each output component gets only
one memory store at the end of the shader.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34780>
This commit is contained in:
Marek Olšák 2025-04-19 02:26:46 -04:00 committed by Marge Bot
parent 509f0e62ad
commit 9c16228359
9 changed files with 206 additions and 25 deletions

View file

@ -119,7 +119,9 @@ ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
void
ac_nir_compute_tess_wg_info(const struct radeon_info *info, uint64_t outputs_read, uint64_t outputs_written,
uint32_t patch_outputs_read, uint32_t patch_outputs_written, unsigned tcs_vertices_out,
uint32_t patch_outputs_read, uint32_t patch_outputs_written,
uint64_t tcs_cross_invocation_outputs_written,
uint64_t outputs_accessed_indirectly, unsigned tcs_vertices_out,
unsigned wave_size, bool tess_uses_primid, bool all_invocations_define_tess_levels,
unsigned num_tcs_input_cp, unsigned lds_input_vertex_size,
unsigned num_mem_tcs_outputs, unsigned num_mem_tcs_patch_outputs,

View file

@ -152,6 +152,15 @@ typedef struct {
nir_variable *tcs_tess_level_inner;
unsigned tcs_tess_level_outer_mask;
unsigned tcs_tess_level_inner_mask;
/* TCS output values, 8 channels per slot. The last 4 channels are high 16 bits of the first 4 channels.
* Output values that are not stored with cross-invocation access and indirect indexing are stored here.
* Output values stored with cross-invocation access or indirect indexing are stored in LDS.
* All outputs are loaded from LDS or VGPRs and written to memory at the end of the shader.
*/
nir_variable *tcs_per_vertex_outputs[VARYING_SLOT_MAX][8];
/* Max. 4 channels, always 32 bits per channel. */
uint8_t tcs_per_vertex_output_vmem_chan_mask[VARYING_SLOT_MAX];
} lower_tess_io_state;
typedef struct {
@ -204,7 +213,9 @@ tcs_output_needs_vmem(nir_intrinsic_instr *intrin,
static uint64_t
tcs_lds_per_vtx_out_mask(nir_shader *shader)
{
return shader->info.outputs_read & shader->info.outputs_written & ~TESS_LVL_MASK;
return ((shader->info.outputs_read & shader->info.outputs_written) |
shader->info.tess.tcs_cross_invocation_outputs_written |
shader->info.outputs_written_indirectly) & ~TESS_LVL_MASK;
}
static uint64_t
@ -537,6 +548,34 @@ lower_hs_per_vertex_input_load(nir_builder *b,
return load;
}
/* Return the function-temp variable backing one output channel, lazily
 * creating a float/float16_t local on first use.
 */
static nir_variable *
get_or_create_output_variable(nir_builder *b, nir_variable **var, unsigned bit_size)
{
   if (*var == NULL) {
      const struct glsl_type *chan_type =
         bit_size == 16 ? &glsl_type_builtin_float16_t : &glsl_type_builtin_float;

      *var = nir_local_variable_create(b->impl, chan_type, NULL);
   }

   return *var;
}
/* Copy the channels selected by write_mask (given relative to "component")
 * from store_val into the per-channel temp variables of one output slot.
 * slot[0..3] hold the 32-bit (or low 16-bit) channels, slot[4..7] the high
 * 16-bit halves.
 */
static void
store_output_variable(nir_builder *b, nir_def *store_val, unsigned write_mask, unsigned component,
                      bool high_16bits, nir_variable **slot)
{
   const unsigned shifted_mask = write_mask << component;

   u_foreach_bit(chan, shifted_mask) {
      /* An already-created channel variable must match the bit size being stored. */
      assert(!slot[chan] ||
             glsl_base_type_bit_size(glsl_get_base_type(slot[chan]->type)) == store_val->bit_size);
      /* The high-16-bit half is only ever populated by 16-bit stores. */
      assert((store_val->bit_size == 16 &&
              (!slot[4 + chan] ||
               glsl_base_type_bit_size(glsl_get_base_type(slot[4 + chan]->type)) == store_val->bit_size)) ||
             (store_val->bit_size == 32 && !slot[4 + chan]));

      nir_variable *var =
         get_or_create_output_variable(b, &slot[chan + (high_16bits ? 4 : 0)], store_val->bit_size);
      nir_store_var(b, var, nir_channel(b, store_val, chan - component), 0x1);
   }
}
static nir_def *
lower_hs_output_store(nir_builder *b,
nir_intrinsic_instr *intrin,
@ -553,13 +592,18 @@ lower_hs_output_store(nir_builder *b,
const bool write_to_vmem = tcs_output_needs_vmem(intrin, b->shader, st);
const bool write_to_lds = tcs_output_needs_lds(intrin, b->shader, st);
if (write_to_vmem) {
nir_def *vmem_off = intrin->intrinsic == nir_intrinsic_store_per_vertex_output
? hs_per_vertex_output_vmem_offset(b, st, semantics.location, component,
nir_get_io_arrayed_index_src(intrin)->ssa,
nir_get_io_offset_src(intrin)->ssa, NULL)
: hs_per_patch_output_vmem_offset(b, st, semantics.location, component,
nir_get_io_offset_src(intrin)->ssa, 0, NULL);
assert(store_val->bit_size & (16 | 32));
if (write_to_vmem && per_vertex) {
for (unsigned slot = 0; slot < semantics.num_slots; slot++) {
st->tcs_per_vertex_output_vmem_chan_mask[semantics.location + slot] |= write_mask << component;
}
}
/* Only store per-patch outputs to memory here. (TODO: do it at the end of the shader) */
if (write_to_vmem && !per_vertex) {
nir_def *vmem_off = hs_per_patch_output_vmem_offset(b, st, semantics.location, component,
nir_get_io_offset_src(intrin)->ssa, 0, NULL);
nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
@ -578,6 +622,17 @@ lower_hs_output_store(nir_builder *b,
nir_store_shared, lds_off, .write_mask = store_write_mask, .base = store_const_offset);
}
/* Store per-vertex outputs to temp variables. The outputs will be stored to memory at the end of the shader. */
if (write_to_vmem && per_vertex &&
!((b->shader->info.tess.tcs_cross_invocation_outputs_written |
b->shader->info.outputs_written_indirectly) & BITFIELD64_BIT(semantics.location))) {
assert(semantics.location < ARRAY_SIZE(st->tcs_per_vertex_outputs));
assert(semantics.num_slots == 1);
store_output_variable(b, store_val, write_mask, component, semantics.high_16bits,
st->tcs_per_vertex_outputs[semantics.location]);
}
/* Save tess factor to be used by tess factor writer or reconstruct
* store output instruction later.
*/
@ -1049,6 +1104,36 @@ hs_store_tess_factors_for_tes(nir_builder *b, tess_levels tessfactors, lower_tes
}
}
/* Build a vec4 from up to 4 channel defs, substituting a 32-bit undef for
 * every missing channel.
 */
static nir_def *
make_vec4(nir_builder *b, nir_def *comp[4])
{
   for (unsigned chan = 0; chan < 4; chan++)
      comp[chan] = comp[chan] ? comp[chan] : nir_undef(b, 1, 32);

   return nir_vec(b, comp, 4);
}
/* Reassemble one 32-bit output channel from its temp variable(s).
 * vec[chan] holds either a full 32-bit value or the low 16 bits; vec[4 + chan]
 * holds the high 16 bits. Callers must only request channels that were
 * written, i.e. at least one of the two variables must exist.
 */
static nir_def *
load_output_channel_from_var(nir_builder *b, nir_variable *vec[8], unsigned chan)
{
   nir_def *lo = NULL, *hi = NULL;

   /* It can be one 32-bit value or two 16-bit values. */
   if (vec[chan])
      lo = nir_load_var(b, vec[chan]);
   if (vec[4 + chan])
      hi = nir_load_var(b, vec[4 + chan]);

   /* Fail loudly on an unwritten channel instead of passing NULL to
    * nir_u2u32 below, which would crash with an opaque NULL dereference.
    */
   assert(lo || hi);

   if (lo && hi)
      return nir_pack_32_2x16_split(b, lo, hi);
   else if (hi)
      return nir_ishl_imm(b, nir_u2u32(b, hi), 16);
   else
      return nir_u2u32(b, lo);
}
static void
hs_finale(nir_shader *shader, lower_tess_io_state *st)
{
@ -1060,8 +1145,10 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
nir_builder builder = nir_builder_at(nir_after_block(last_block));
nir_builder *b = &builder; /* This is to avoid the & */
/* If tess factors are loaded from LDS, wait for their LDS stores. */
if (!st->tcs_info.all_invocations_define_tess_levels) {
/* Insert a barrier to wait for output stores to LDS. */
if (!st->tcs_info.all_invocations_define_tess_levels ||
shader->info.tess.tcs_cross_invocation_outputs_written ||
shader->info.outputs_written_indirectly) {
mesa_scope scope = st->tcs_out_patch_fits_subgroup ? SCOPE_SUBGROUP : SCOPE_WORKGROUP;
nir_barrier(b, .execution_scope = scope, .memory_scope = scope,
.memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
@ -1118,6 +1205,33 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
}
nir_pop_if(b, if_invocation_id_zero);
/* Gather per-vertex output values from local variables and LDS. */
nir_def *outputs[VARYING_SLOT_MAX] = {0};
nir_def *invocation_id = nir_load_invocation_id(b);
nir_def *zero = nir_imm_int(b, 0);
u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) {
if (!st->tcs_per_vertex_output_vmem_chan_mask[slot])
continue;
nir_def *comp[4] = {0};
/* Gather stored components either from LDS or from local variables. */
if ((shader->info.tess.tcs_cross_invocation_outputs_written |
shader->info.outputs_written_indirectly) & BITFIELD64_BIT(slot)) {
u_foreach_bit(i, st->tcs_per_vertex_output_vmem_chan_mask[slot]) {
nir_def *lds_off = hs_output_lds_offset(b, st, slot, i, invocation_id, zero);
comp[i] = nir_load_shared(b, 1, 32, lds_off);
}
} else {
u_foreach_bit(i, st->tcs_per_vertex_output_vmem_chan_mask[slot]) {
comp[i] = load_output_channel_from_var(b, st->tcs_per_vertex_outputs[slot], i);
}
}
outputs[slot] = make_vec4(b, comp);
}
if (st->gfx_level >= GFX9) {
/* Wrap the whole shader in a conditional block, allowing only TCS (HS) invocations to execute
* in the LS-HS workgroup.
@ -1133,8 +1247,51 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
nir_cf_reinsert(extracted, b->cursor);
}
nir_pop_if(b, if_tcs);
u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) {
if (outputs[slot])
outputs[slot] = nir_if_phi(b, outputs[slot], nir_undef(b, 4, 32));
}
}
/* Store per-vertex outputs to memory. */
nir_def *is_pervertex_store_thread = nir_imm_true(b);
/* Align the EXEC mask to 8 lanes to overwrite whole 128B blocks on GFX10+, or 4 lanes to
* overwrite whole 64B blocks on GFX9.
*
* GFX6-8 can't align the EXEC mask because it's not ~0.
*/
if (st->gfx_level >= GFX9) {
unsigned align = st->gfx_level >= GFX10 ? 8 : 4;
nir_def *num_tcs_threads = nir_ubfe_imm(b, nir_load_merged_wave_info_amd(b), 8, 8);
nir_def *aligned_tcs_threads = nir_align_imm(b, num_tcs_threads, align);
is_pervertex_store_thread = nir_is_subgroup_invocation_lt_amd(b, aligned_tcs_threads);
}
nir_if *if_pervertex_stores = nir_push_if(b, is_pervertex_store_thread);
{
nir_def *hs_ring_tess_offchip = nir_load_ring_tess_offchip_amd(b);
nir_def *offchip_offset = nir_load_ring_tess_offchip_offset_amd(b);
nir_def *local_invocation_index = nir_load_local_invocation_index(b);
nir_def *zero = nir_imm_int(b, 0);
u_foreach_bit64(slot, tcs_vram_per_vtx_out_mask(shader, st)) {
if (!outputs[slot])
continue;
nir_def *vmem_off = hs_per_vertex_output_vmem_offset(b, st, slot, 0, local_invocation_index,
zero, zero);
/* Always store whole vec4s to get cached bandwidth. Non-vec4 stores cause implicit memory loads
* to fill the rest of cache lines with this layout.
*/
nir_store_buffer_amd(b, outputs[slot], hs_ring_tess_offchip, vmem_off, offchip_offset, zero,
.memory_modes = nir_var_shader_out, .access = ACCESS_COHERENT);
}
}
nir_pop_if(b, if_pervertex_stores);
nir_progress(true, impl, nir_metadata_none);
}
@ -1262,6 +1419,8 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, const nir_tcs_info *info,
{
assert(shader->info.stage == MESA_SHADER_TESS_CTRL);
NIR_PASS(_, shader, nir_io_add_const_offset_to_base, nir_var_shader_out);
lower_tess_io_state state = {
.gfx_level = gfx_level,
.tcs_info = *info,
@ -1286,13 +1445,10 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader, const nir_tcs_info *info,
hs_finale(shader, &state);
/* Cleanup the local variable for tess levels. */
if (state.tcs_info.all_invocations_define_tess_levels) {
NIR_PASS(_, shader, nir_lower_vars_to_ssa);
NIR_PASS(_, shader, nir_remove_dead_variables, nir_var_function_temp, NULL);
NIR_PASS(_, shader, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS(_, shader, nir_lower_phis_to_scalar, true);
}
NIR_PASS(_, shader, nir_lower_vars_to_ssa);
NIR_PASS(_, shader, nir_remove_dead_variables, nir_var_function_temp, NULL);
NIR_PASS(_, shader, nir_lower_alu_to_scalar, NULL, NULL);
NIR_PASS(_, shader, nir_lower_phis_to_scalar, true);
return true;
}
@ -1317,7 +1473,9 @@ ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
void
ac_nir_compute_tess_wg_info(const struct radeon_info *info, uint64_t outputs_read, uint64_t outputs_written,
uint32_t patch_outputs_read, uint32_t patch_outputs_written, unsigned tcs_vertices_out,
uint32_t patch_outputs_read, uint32_t patch_outputs_written,
uint64_t tcs_cross_invocation_outputs_written,
uint64_t outputs_accessed_indirectly, unsigned tcs_vertices_out,
unsigned wave_size, bool tess_uses_primid, bool all_invocations_define_tess_levels,
unsigned num_tcs_input_cp, unsigned lds_input_vertex_size,
unsigned num_mem_tcs_outputs, unsigned num_mem_tcs_patch_outputs,
@ -1325,7 +1483,8 @@ ac_nir_compute_tess_wg_info(const struct radeon_info *info, uint64_t outputs_rea
{
unsigned num_tcs_output_cp = tcs_vertices_out;
unsigned lds_output_vertex_size =
util_bitcount64(outputs_read & outputs_written & ~TESS_LVL_MASK) * 16;
util_bitcount64((((outputs_read & outputs_written) | tcs_cross_invocation_outputs_written |
outputs_accessed_indirectly) & ~TESS_LVL_MASK)) * 16;
unsigned lds_perpatch_output_patch_size =
(util_bitcount64(all_invocations_define_tess_levels ?
0 : outputs_written & TESS_LVL_MASK) +

View file

@ -3623,8 +3623,12 @@ radv_emit_patch_control_points(struct radv_cmd_buffer *cmd_buffer)
/* These are only used to determine the LDS layout for TCS outputs. */
tcs_info.outputs_read = tcs->info.tcs.tcs_outputs_read;
tcs_info.outputs_written = tcs->info.tcs.tcs_outputs_written;
/* "read" and "written" are OR'd by radv_get_tess_wg_info. */
tcs_info.outputs_read_indirectly = tcs->info.tcs.tcs_outputs_accessed_indirectly;
tcs_info.outputs_written_indirectly = tcs->info.tcs.tcs_outputs_accessed_indirectly;
tcs_info.patch_outputs_read = tcs->info.tcs.tcs_patch_outputs_read;
tcs_info.patch_outputs_written = tcs->info.tcs.tcs_patch_outputs_written;
tcs_info.tess.tcs_cross_invocation_outputs_written = tcs->info.tcs.tcs_cross_invocation_outputs_written;
radv_get_tess_wg_info(pdev, &tcs_info, d->vk.ts.patch_control_points,
/* TODO: This should be only inputs in LDS (not VGPR inputs) to reduce LDS usage */

View file

@ -3609,11 +3609,12 @@ radv_get_tess_wg_info(const struct radv_physical_device *pdev, const struct shad
{
const uint32_t lds_input_vertex_size = get_tcs_input_vertex_stride(tcs_num_lds_inputs);
ac_nir_compute_tess_wg_info(&pdev->info, tcs_info->outputs_read, tcs_info->outputs_written,
tcs_info->patch_outputs_read, tcs_info->patch_outputs_written,
tcs_info->tess.tcs_vertices_out, pdev->ge_wave_size, false,
all_invocations_define_tess_levels, tcs_num_input_vertices, lds_input_vertex_size,
tcs_num_vram_outputs, tcs_num_vram_patch_outputs, num_patches_per_wg, hw_lds_size);
ac_nir_compute_tess_wg_info(
&pdev->info, tcs_info->outputs_read, tcs_info->outputs_written, tcs_info->patch_outputs_read,
tcs_info->patch_outputs_written, tcs_info->tess.tcs_cross_invocation_outputs_written,
tcs_info->outputs_read_indirectly | tcs_info->outputs_written_indirectly, tcs_info->tess.tcs_vertices_out,
pdev->ge_wave_size, false, all_invocations_define_tess_levels, tcs_num_input_vertices, lds_input_vertex_size,
tcs_num_vram_outputs, tcs_num_vram_patch_outputs, num_patches_per_wg, hw_lds_size);
}
VkResult

View file

@ -635,8 +635,10 @@ gather_shader_info_tcs(struct radv_device *device, const nir_shader *nir,
info->tcs.tcs_outputs_read = nir->info.outputs_read;
info->tcs.tcs_outputs_written = nir->info.outputs_written;
info->tcs.tcs_outputs_accessed_indirectly = nir->info.outputs_read_indirectly | nir->info.outputs_written_indirectly;
info->tcs.tcs_patch_outputs_read = nir->info.patch_outputs_read;
info->tcs.tcs_patch_outputs_written = nir->info.patch_outputs_written;
info->tcs.tcs_cross_invocation_outputs_written = nir->info.tess.tcs_cross_invocation_outputs_written;
info->tcs.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
info->tcs.tes_inputs_read = ~0ULL;
info->tcs.tes_patch_inputs_read = ~0ULL;

View file

@ -242,8 +242,10 @@ struct radv_shader_info {
uint64_t tes_patch_inputs_read;
uint64_t tcs_outputs_read;
uint64_t tcs_outputs_written;
uint64_t tcs_outputs_accessed_indirectly;
uint32_t tcs_patch_outputs_read;
uint32_t tcs_patch_outputs_written;
uint64_t tcs_cross_invocation_outputs_written;
unsigned tcs_vertices_out;
uint32_t num_lds_blocks;
uint8_t num_linked_inputs; /* Number of reserved per-vertex input slots in LDS. */

View file

@ -411,6 +411,8 @@ void si_nir_scan_shader(struct si_screen *sscreen, struct nir_shader *nir,
info->base.outputs_written = nir->info.outputs_written;
info->base.patch_outputs_read = nir->info.patch_outputs_read;
info->base.patch_outputs_written = nir->info.patch_outputs_written;
info->base.outputs_read_indirectly = nir->info.outputs_read_indirectly;
info->base.outputs_written_indirectly = nir->info.outputs_written_indirectly;
info->base.num_ubos = nir->info.num_ubos;
info->base.num_ssbos = nir->info.num_ssbos;
@ -440,6 +442,8 @@ void si_nir_scan_shader(struct si_screen *sscreen, struct nir_shader *nir,
info->base.tess.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
info->base.tess.ccw = nir->info.tess.ccw;
info->base.tess.point_mode = nir->info.tess.point_mode;
info->base.tess.tcs_cross_invocation_outputs_read = nir->info.tess.tcs_cross_invocation_outputs_read;
info->base.tess.tcs_cross_invocation_outputs_written = nir->info.tess.tcs_cross_invocation_outputs_written;
break;
case MESA_SHADER_GEOMETRY:

View file

@ -34,6 +34,8 @@ struct si_shader_info {
uint64_t outputs_written;
uint32_t patch_outputs_read;
uint32_t patch_outputs_written;
uint64_t outputs_read_indirectly;
uint64_t outputs_written_indirectly;
uint8_t num_ubos;
uint8_t num_ssbos;
@ -62,6 +64,8 @@ struct si_shader_info {
uint8_t tcs_vertices_out;
bool ccw:1;
bool point_mode:1;
uint64_t tcs_cross_invocation_outputs_read;
uint64_t tcs_cross_invocation_outputs_written;
} tess;
struct {

View file

@ -4820,6 +4820,9 @@ void si_update_tess_io_layout_state(struct si_context *sctx)
ac_nir_compute_tess_wg_info(&sctx->screen->info, tcs->info.base.outputs_read,
tcs->info.base.outputs_written, tcs->info.base.patch_outputs_read,
tcs->info.base.patch_outputs_written,
tcs->info.base.tess.tcs_cross_invocation_outputs_written,
tcs->info.base.outputs_read_indirectly |
tcs->info.base.outputs_written_indirectly,
tcs->info.base.tess.tcs_vertices_out, ls_current->wave_size,
tess_uses_primid, tcs->info.tessfactors_are_def_in_all_invocs,
num_tcs_input_cp, lds_input_vertex_size,