From 2514999c9c5d4b64b007304e46f0a41d50e71cc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ol=C5=A1=C3=A1k?= Date: Sun, 26 May 2024 23:02:42 -0400 Subject: [PATCH] nir: add nir_opt_vectorize_io, vectorizing lowered IO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since nir_opt_varyings requires scalar IO and thus all drivers have to scalarize it, this gives the option to re-vectorize IO after that. Reviewed-by: Timur Kristóf Part-of: --- src/compiler/nir/meson.build | 1 + src/compiler/nir/nir.h | 1 + src/compiler/nir/nir_opt_vectorize_io.c | 558 ++++++++++++++++++++++++ 3 files changed, 560 insertions(+) create mode 100644 src/compiler/nir/nir_opt_vectorize_io.c diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build index 5a130724ca8..f340b19d502 100644 --- a/src/compiler/nir/meson.build +++ b/src/compiler/nir/meson.build @@ -276,6 +276,7 @@ files_libnir = files( 'nir_opt_uniform_subgroup.c', 'nir_opt_varyings.c', 'nir_opt_vectorize.c', + 'nir_opt_vectorize_io.c', 'nir_passthrough_gs.c', 'nir_passthrough_tcs.c', 'nir_phi_builder.c', diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index aad0e3e8282..3036cd76a05 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6636,6 +6636,7 @@ bool nir_opt_uniform_subgroup(nir_shader *shader, bool nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter, void *data); +bool nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes); bool nir_opt_conditional_discard(nir_shader *shader); bool nir_opt_move_discards_to_top(nir_shader *shader); diff --git a/src/compiler/nir/nir_opt_vectorize_io.c b/src/compiler/nir/nir_opt_vectorize_io.c new file mode 100644 index 00000000000..5306f2112ae --- /dev/null +++ b/src/compiler/nir/nir_opt_vectorize_io.c @@ -0,0 +1,558 @@ +/* + * Copyright 2024 Advanced Micro Devices, Inc. + * + * SPDX-License-Identifier: MIT + */ + +/** + * This pass: + * - vectorizes lowered input/output loads and stores + * - vectorizes low and high 16-bit loads and stores by merging them into + * a single 32-bit load or store (except load_interpolated_input, which has + * to keep bit_size=16) + * - performs DCE of output stores that overwrite the previous value by writing + * into the same slot and component. + * + * Vectorization is only local within basic blocks. No vectorization occurs + * across basic block boundaries, barriers (only TCS outputs), emits (only + * GS outputs), and output load <-> output store dependencies. + * + * All loads and stores must be scalar. 64-bit loads and stores are forbidden. + * + * For each basic block, the time complexity is O(n*log(n)) where n is + * the number of IO instructions within that block. + */ + +#include "nir.h" +#include "nir_builder.h" +#include "util/u_dynarray.h" + +/* Return 0 if loads/stores are vectorizable. Return 1 or -1 to define + * an ordering between non-vectorizable instructions. This is used by qsort, + * to sort all gathered instructions into groups of vectorizable instructions. + */ +static int +compare_is_not_vectorizable(nir_intrinsic_instr *a, nir_intrinsic_instr *b) +{ + if (a->intrinsic != b->intrinsic) + return a->intrinsic > b->intrinsic ? 1 : -1; + + nir_src *offset0 = nir_get_io_offset_src(a); + nir_src *offset1 = nir_get_io_offset_src(b); + if (offset0 && offset0->ssa != offset1->ssa) + return offset0->ssa->index > offset1->ssa->index ? 
1 : -1; + + nir_src *array_idx0 = nir_get_io_arrayed_index_src(a); + nir_src *array_idx1 = nir_get_io_arrayed_index_src(b); + if (array_idx0 && array_idx0->ssa != array_idx1->ssa) + return array_idx0->ssa->index > array_idx1->ssa->index ? 1 : -1; + + /* Compare barycentrics or vertex index. */ + if ((a->intrinsic == nir_intrinsic_load_interpolated_input || + a->intrinsic == nir_intrinsic_load_input_vertex) && + a->src[0].ssa != b->src[0].ssa) + return a->src[0].ssa->index > b->src[0].ssa->index ? 1 : -1; + + nir_io_semantics sem0 = nir_intrinsic_io_semantics(a); + nir_io_semantics sem1 = nir_intrinsic_io_semantics(b); + if (sem0.location != sem1.location) + return sem0.location > sem1.location ? 1 : -1; + + /* The mediump flag isn't mergable. */ + if (sem0.medium_precision != sem1.medium_precision) + return sem0.medium_precision > sem1.medium_precision ? 1 : -1; + + /* Don't merge per-view attributes with non-per-view attributes. */ + if (sem0.per_view != sem1.per_view) + return sem0.per_view > sem1.per_view ? 1 : -1; + + if (sem0.interp_explicit_strict != sem1.interp_explicit_strict) + return sem0.interp_explicit_strict > sem1.interp_explicit_strict ? 1 : -1; + + if (sem0.per_primitive != sem1.per_primitive) + return sem0.per_primitive > sem1.per_primitive ? 1 : -1; + + /* Only load_interpolated_input can't merge low and high halves of 16-bit + * loads/stores. + */ + if (a->intrinsic == nir_intrinsic_load_interpolated_input && + sem0.high_16bits != sem1.high_16bits) + return sem0.high_16bits > sem1.high_16bits ? 1 : -1; + + return 0; +} + +static int +compare_intr(const void *xa, const void *xb) +{ + nir_intrinsic_instr *a = *(nir_intrinsic_instr **)xa; + nir_intrinsic_instr *b = *(nir_intrinsic_instr **)xb; + + int comp = compare_is_not_vectorizable(a, b); + if (comp) + return comp; + + /* qsort isn't stable. This ensures that later stores aren't moved before earlier stores. */ + return a->instr.index > b->instr.index ? 1 : -1; +} + +static void +vectorize_load(nir_intrinsic_instr *chan[8], unsigned start, unsigned count, + bool merge_low_high_16_to_32) +{ + nir_intrinsic_instr *first = NULL; + + /* Find the first instruction where the vectorized load will be + * inserted. + */ + for (unsigned i = start; i < start + count; i++) { + first = !first || chan[i]->instr.index < first->instr.index ? + chan[i] : first; + if (merge_low_high_16_to_32) { + first = !first || chan[4 + i]->instr.index < first->instr.index ? + chan[4 + i] : first; + } + } + + /* Insert the vectorized load. */ + nir_builder b = nir_builder_at(nir_before_instr(&first->instr)); + nir_intrinsic_instr *new_intr = + nir_intrinsic_instr_create(b.shader, first->intrinsic); + + new_intr->num_components = count; + nir_def_init(&new_intr->instr, &new_intr->def, count, + merge_low_high_16_to_32 ? 32 : first->def.bit_size); + memcpy(new_intr->src, first->src, + nir_intrinsic_infos[first->intrinsic].num_srcs * sizeof(nir_src)); + nir_intrinsic_copy_const_indices(new_intr, first); + nir_intrinsic_set_component(new_intr, start); + + if (merge_low_high_16_to_32) { + nir_io_semantics sem = nir_intrinsic_io_semantics(new_intr); + sem.high_16bits = 0; + nir_intrinsic_set_io_semantics(new_intr, sem); + nir_intrinsic_set_dest_type(new_intr, + (nir_intrinsic_dest_type(new_intr) & ~16) | 32); + } + + nir_builder_instr_insert(&b, &new_intr->instr); + nir_def *def = &new_intr->def; + + /* Replace the scalar loads. 
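When low and high 16-bit halves were merged into a 32-bit load, each lane of the new vector is split back with nir_unpack_32_2x16_split_x/y below, so the original 16-bit users keep their bit size. 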
*/ + if (merge_low_high_16_to_32) { + for (unsigned i = start; i < start + count; i++) { + nir_def *comp = nir_channel(&b, def, i - start); + + nir_def_rewrite_uses(&chan[i]->def, + nir_unpack_32_2x16_split_x(&b, comp)); + nir_def_rewrite_uses(&chan[4 + i]->def, + nir_unpack_32_2x16_split_y(&b, comp)); + nir_instr_remove(&chan[i]->instr); + nir_instr_remove(&chan[4 + i]->instr); + } + } else { + for (unsigned i = start; i < start + count; i++) { + nir_def_rewrite_uses(&chan[i]->def, nir_channel(&b, def, i - start)); + nir_instr_remove(&chan[i]->instr); + } + } +} + +static void +vectorize_store(nir_intrinsic_instr *chan[8], unsigned start, unsigned count, + bool merge_low_high_16_to_32) +{ + nir_intrinsic_instr *last = NULL; + + /* Find the last instruction where the vectorized store will be + * inserted. + */ + for (unsigned i = start; i < start + count; i++) { + last = !last || chan[i]->instr.index > last->instr.index ? + chan[i] : last; + if (merge_low_high_16_to_32) { + last = !last || chan[4 + i]->instr.index > last->instr.index ? + chan[4 + i] : last; + } + } + + /* Change the last instruction to a vectorized store. Update xfb first + * because we need to read some info from "last" before overwriting it. + */ + if (nir_intrinsic_has_io_xfb(last)) { + nir_io_xfb xfb[2] = {{{{0}}}}; + + for (unsigned i = start; i < start + count; i++) { + xfb[i / 2].out[i % 2] = + (i < 2 ? nir_intrinsic_io_xfb(chan[i]) : + nir_intrinsic_io_xfb2(chan[i])).out[i % 2]; + + /* Merging low and high 16 bits to 32 bits is not possible + * with xfb in some cases. (and it's not implemented for + * cases where it's possible) + */ + assert(!xfb[i / 2].out[i % 2].num_components || + !merge_low_high_16_to_32); + } + + /* Now vectorize xfb info by merging the individual elements. */ + for (unsigned i = start; i < start + count; i++) { + /* mediump means that xfb upconverts to 32 bits when writing to + * memory. + */ + unsigned xfb_comp_size = + nir_intrinsic_io_semantics(chan[i]).medium_precision ? + 32 : chan[i]->src[0].ssa->bit_size; + + for (unsigned j = i + 1; j < start + count; j++) { + if (xfb[i / 2].out[i % 2].buffer != xfb[j / 2].out[j % 2].buffer || + xfb[i / 2].out[i % 2].offset != xfb[j / 2].out[j % 2].offset + + xfb_comp_size * (j - i)) + break; + + xfb[i / 2].out[i % 2].num_components++; + memset(&xfb[j / 2].out[j % 2], 0, sizeof(xfb[j / 2].out[j % 2])); + } + } + + nir_intrinsic_set_io_xfb(last, xfb[0]); + nir_intrinsic_set_io_xfb2(last, xfb[1]); + } + + /* Update gs_streams. */ + unsigned gs_streams = 0; + for (unsigned i = start; i < start + count; i++) { + gs_streams |= (nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) << + ((i - start) * 2); + } + + nir_io_semantics sem = nir_intrinsic_io_semantics(last); + sem.gs_streams = gs_streams; + + /* Update other flags. */ + for (unsigned i = start; i < start + count; i++) { + if (!nir_intrinsic_io_semantics(chan[i]).no_sysval_output) + sem.no_sysval_output = 0; + if (!nir_intrinsic_io_semantics(chan[i]).no_varying) + sem.no_varying = 0; + if (nir_intrinsic_io_semantics(chan[i]).invariant) + sem.invariant = 1; + } + + if (merge_low_high_16_to_32) { + /* Update "no" flags for high bits. */ + for (unsigned i = start; i < start + count; i++) { + if (!nir_intrinsic_io_semantics(chan[4 + i]).no_sysval_output) + sem.no_sysval_output = 0; + if (!nir_intrinsic_io_semantics(chan[4 + i]).no_varying) + sem.no_varying = 0; + if (nir_intrinsic_io_semantics(chan[4 + i]).invariant) + sem.invariant = 1; + } + + /* Update the type. 
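The merged store now writes full 32-bit values, so drop the high_16bits flag and widen the source type from 16 to 32 bits while keeping its base type. 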
*/
+      sem.high_16bits = 0;
+      nir_intrinsic_set_src_type(last,
+                                 (nir_intrinsic_src_type(last) & ~16) | 32);
+   }
+
+   /* TODO: Merge names? */
+
+   /* Update the rest. */
+   nir_intrinsic_set_io_semantics(last, sem);
+   nir_intrinsic_set_component(last, start);
+   nir_intrinsic_set_write_mask(last, BITFIELD_MASK(count));
+   last->num_components = count;
+
+   nir_builder b = nir_builder_at(nir_before_instr(&last->instr));
+
+   /* Replace the stored scalar with the vector. */
+   if (merge_low_high_16_to_32) {
+      nir_def *value[4];
+      for (unsigned i = start; i < start + count; i++) {
+         value[i] = nir_pack_32_2x16_split(&b, chan[i]->src[0].ssa,
+                                           chan[4 + i]->src[0].ssa);
+      }
+
+      nir_src_rewrite(&last->src[0], nir_vec(&b, &value[start], count));
+   } else {
+      nir_def *value[4];
+      for (unsigned i = start; i < start + count; i++)
+         value[i] = chan[i]->src[0].ssa;
+
+      nir_src_rewrite(&last->src[0], nir_vec(&b, &value[start], count));
+   }
+
+   /* Remove the scalar stores. */
+   for (unsigned i = start; i < start + count; i++) {
+      if (chan[i] != last)
+         nir_instr_remove(&chan[i]->instr);
+      if (merge_low_high_16_to_32 && chan[4 + i] != last)
+         nir_instr_remove(&chan[4 + i]->instr);
+   }
+}
+
+/* Vectorize a vector of scalar instructions. chan[8] are the channels.
+ * (the last 4 are the high 16-bit channels)
+ */
+static bool
+vectorize_slot(nir_intrinsic_instr *chan[8], unsigned mask)
+{
+   bool progress = false;
+
+   /* First, merge low and high 16-bit halves into 32 bits separately when
+    * possible. Then vectorize what's left.
+    */
+   for (int merge_low_high_16_to_32 = 1; merge_low_high_16_to_32 >= 0;
+        merge_low_high_16_to_32--) {
+      unsigned scan_mask;
+
+      if (merge_low_high_16_to_32) {
+         /* Get the subset of the mask where both low and high bits are set. */
+         scan_mask = 0;
+         for (unsigned i = 0; i < 4; i++) {
+            unsigned low_high_bits = BITFIELD_BIT(i) | BITFIELD_BIT(i + 4);
+
+            if ((mask & low_high_bits) == low_high_bits) {
+               /* Merging low and high 16 bits to 32 bits is not possible
+                * with xfb in some cases. (and it's not implemented for
+                * cases where it's possible)
+                */
+               if (nir_intrinsic_has_io_xfb(chan[i])) {
+                  unsigned hi = i + 4;
+
+                  if ((i < 2 ? nir_intrinsic_io_xfb(chan[i])
+                             : nir_intrinsic_io_xfb2(chan[i])).out[i % 2].num_components ||
+                      (i < 2 ? nir_intrinsic_io_xfb(chan[hi])
+                             : nir_intrinsic_io_xfb2(chan[hi])).out[i % 2].num_components)
+                     continue;
+               }
+
+               /* The GS stream must be the same for both halves. */
+               if ((nir_intrinsic_io_semantics(chan[i]).gs_streams & 0x3) !=
+                   (nir_intrinsic_io_semantics(chan[4 + i]).gs_streams & 0x3))
+                  continue;
+
+               scan_mask |= BITFIELD_BIT(i);
+               mask &= ~low_high_bits;
+            }
+         }
+      } else {
+         scan_mask = mask;
+      }
+
+      while (scan_mask) {
+         int start, count;
+
+         u_bit_scan_consecutive_range(&scan_mask, &start, &count);
+
+         if (count == 1 && !merge_low_high_16_to_32)
+            continue; /* There is nothing to vectorize. */
+
+         bool is_load = nir_intrinsic_infos[chan[start]->intrinsic].has_dest;
+
+         if (is_load)
+            vectorize_load(chan, start, count, merge_low_high_16_to_32);
+         else
+            vectorize_store(chan, start, count, merge_low_high_16_to_32);
+
+         progress = true;
+      }
+   }
+
+   return progress;
+}
+
+static bool
+vectorize_batch(struct util_dynarray *io_instructions)
+{
+   unsigned num_instr = util_dynarray_num_elements(io_instructions, void *);
+
+   /* We need at least 2 instructions to have something to do. */
+   if (num_instr <= 1) {
+      /* Clear the array. The next block will reuse it. 
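The caller allocates a single dynarray for the whole shader, so clearing it here (rather than freeing it) avoids reallocations between blocks. 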
*/
+      util_dynarray_clear(io_instructions);
+      return false;
+   }
+
+   /* The instructions are sorted such that groups of vectorizable
+    * instructions are next to each other. Multiple incompatible
+    * groups of vectorizable instructions can occur in this array.
+    * The reason why 2 groups would be incompatible is that they
+    * could have a different intrinsic, indirect index, array index,
+    * vertex index, barycentrics, or location. Each group is vectorized
+    * separately.
+    *
+    * This reorders instructions in the array, but not in the shader.
+    */
+   qsort(io_instructions->data, num_instr, sizeof(void*), compare_intr);
+
+   nir_intrinsic_instr *chan[8] = {0}, *prev = NULL;
+   unsigned chan_mask = 0;
+   bool progress = false;
+
+   /* Vectorize all groups.
+    *
+    * The channels for each group are gathered. If 2 stores overwrite
+    * the same channel, the earlier store is DCE'd here.
+    */
+   util_dynarray_foreach(io_instructions, nir_intrinsic_instr *, intr) {
+      /* If the next instruction is not vectorizable, vectorize what
+       * we have gathered so far.
+       */
+      if (prev && compare_is_not_vectorizable(prev, *intr)) {
+         /* We need at least 2 instructions to have something to do. */
+         if (util_bitcount(chan_mask) > 1)
+            progress |= vectorize_slot(chan, chan_mask);
+
+         prev = NULL;
+         memset(chan, 0, sizeof(chan));
+         chan_mask = 0;
+      }
+
+      /* This performs DCE of output stores because the previous value
+       * is being overwritten.
+       */
+      unsigned index = nir_intrinsic_io_semantics(*intr).high_16bits * 4 +
+                       nir_intrinsic_component(*intr);
+      bool is_store = !nir_intrinsic_infos[(*intr)->intrinsic].has_dest;
+      if (is_store && chan[index])
+         nir_instr_remove(&chan[index]->instr);
+
+      /* Gather the channel. */
+      chan[index] = *intr;
+      prev = *intr;
+      chan_mask |= BITFIELD_BIT(index);
+   }
+
+   /* Vectorize the last group. */
+   if (prev && util_bitcount(chan_mask) > 1)
+      progress |= vectorize_slot(chan, chan_mask);
+
+   /* Clear the array. The next block will reuse it. */
+   util_dynarray_clear(io_instructions);
+   return progress;
+}
+
+bool
+nir_opt_vectorize_io(nir_shader *shader, nir_variable_mode modes)
+{
+   assert(!(modes & ~(nir_var_shader_in | nir_var_shader_out)));
+
+   if ((shader->info.stage == MESA_SHADER_TESS_CTRL ||
+        shader->info.stage == MESA_SHADER_GEOMETRY) &&
+       util_bitcount(modes) == 2) {
+      /* When vectorizing TCS and GS IO, inputs can ignore barriers and emits,
+       * but that is only done when outputs are ignored, so vectorize them
+       * separately.
+       */
+      /* Both calls must run, even if the first one makes progress. */
+      bool progress = nir_opt_vectorize_io(shader, nir_var_shader_in);
+      progress |= nir_opt_vectorize_io(shader, nir_var_shader_out);
+      return progress;
+   }
+
+   /* Initialize the dynamic array. */
+   struct util_dynarray io_instructions;
+   util_dynarray_init(&io_instructions, NULL);
+   bool global_progress = false;
+
+   nir_foreach_function_impl(impl, shader) {
+      bool progress = false;
+      nir_metadata_require(impl, nir_metadata_instr_index);
+
+      nir_foreach_block(block, impl) {
+         BITSET_DECLARE(has_output_loads, NUM_TOTAL_VARYING_SLOTS * 8);
+         BITSET_DECLARE(has_output_stores, NUM_TOTAL_VARYING_SLOTS * 8);
+         BITSET_ZERO(has_output_loads);
+         BITSET_ZERO(has_output_stores);
+
+         /* Gather load/store intrinsics within the block. 
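The batch is flushed early at barriers that affect outputs, at GS emits, and when an output load is followed by an output store to the same channel (or vice versa); everything else gathered here is vectorized at the end of the block. 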
*/ + nir_foreach_instr(instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + bool is_load = nir_intrinsic_infos[intr->intrinsic].has_dest; + bool is_output = false; + nir_io_semantics sem = {0}; + unsigned index = 0; + + if (nir_intrinsic_has_io_semantics(intr)) { + sem = nir_intrinsic_io_semantics(intr); + assert(sem.location < NUM_TOTAL_VARYING_SLOTS); + index = sem.location * 8 + sem.high_16bits * 4 + + nir_intrinsic_component(intr); + } + + switch (intr->intrinsic) { + case nir_intrinsic_load_input: + case nir_intrinsic_load_input_vertex: + case nir_intrinsic_load_interpolated_input: + case nir_intrinsic_load_per_vertex_input: + if (!(modes & nir_var_shader_in)) + continue; + break; + + case nir_intrinsic_load_output: + case nir_intrinsic_load_per_vertex_output: + case nir_intrinsic_load_per_primitive_output: + case nir_intrinsic_store_output: + case nir_intrinsic_store_per_vertex_output: + case nir_intrinsic_store_per_primitive_output: + if (!(modes & nir_var_shader_out)) + continue; + + /* Break the batch if an output load is followed by an output + * store to the same channel and vice versa. + */ + if (BITSET_TEST(is_load ? has_output_stores : has_output_loads, + index)) { + progress |= vectorize_batch(&io_instructions); + BITSET_ZERO(has_output_loads); + BITSET_ZERO(has_output_stores); + } + is_output = true; + break; + + case nir_intrinsic_barrier: + /* Don't vectorize across TCS barriers. */ + if (modes & nir_var_shader_out && + nir_intrinsic_memory_modes(intr) & nir_var_shader_out) { + progress |= vectorize_batch(&io_instructions); + BITSET_ZERO(has_output_loads); + BITSET_ZERO(has_output_stores); + } + continue; + + case nir_intrinsic_emit_vertex: + /* Don't vectorize across GS emits. */ + progress |= vectorize_batch(&io_instructions); + BITSET_ZERO(has_output_loads); + BITSET_ZERO(has_output_stores); + continue; + + default: + continue; + } + + /* Only scalar 16 and 32-bit instructions are allowed. */ + ASSERTED nir_def *value = is_load ? &intr->def : intr->src[0].ssa; + assert(value->num_components == 1); + assert(value->bit_size == 16 || value->bit_size == 32); + + util_dynarray_append(&io_instructions, void *, intr); + if (is_output) + BITSET_SET(is_load ? has_output_loads : has_output_stores, index); + } + + progress |= vectorize_batch(&io_instructions); + } + + nir_metadata_preserve(impl, progress ? (nir_metadata_block_index | + nir_metadata_dominance) : + nir_metadata_all); + global_progress |= progress; + } + util_dynarray_fini(&io_instructions); + + return global_progress; +}
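
A minimal usage sketch (not part of the patch): how a driver that has already
lowered IO to scalar intrinsics and run nir_opt_varyings might re-vectorize
afterwards. The helper name and call site are illustrative assumptions; only
nir_opt_vectorize_io, NIR_PASS and the nir_var_shader_* modes come from NIR.

   #include "nir.h"

   /* Hypothetical helper: re-vectorize the lowered IO of one shader.
    * Passing both modes is fine for TCS/GS too; the pass internally
    * handles inputs and outputs as separate sub-passes.
    */
   static bool
   example_revectorize_io(nir_shader *nir)
   {
      bool progress = false;

      NIR_PASS(progress, nir, nir_opt_vectorize_io,
               nir_var_shader_in | nir_var_shader_out);
      return progress;
   }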