From 6fe0cfdc09d1c8aecc334f1d9f018c8ad5ea46ca Mon Sep 17 00:00:00 2001
From: Mike Blumenkrantz
Date: Thu, 21 Mar 2024 14:52:25 -0400
Subject: [PATCH] zink: vectorize io loads/stores when possible

Part-of:
---
 src/gallium/drivers/zink/zink_compiler.c | 287 +++++++++++++++++++++--
 1 file changed, 271 insertions(+), 16 deletions(-)

diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index f3052a1d3ac..443ed41e2d5 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -5675,18 +5675,6 @@ rework_io_vars(nir_shader *nir, nir_variable_mode mode, struct zink_shader *zs)
    loop_io_var_mask(nir, mode, false, false, mask);
 }
 
-/* can't scalarize these */
-static bool
-skip_scalarize(const nir_instr *instr, const void *data)
-{
-   if (instr->type != nir_instr_type_intrinsic)
-      return false;
-
-   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
-   nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
-   return !sem.fb_fetch_output && sem.num_slots == 1;
-}
-
 static int
 zink_type_size(const struct glsl_type *type, bool bindless)
 {
@@ -5800,6 +5788,274 @@ fix_vertex_input_locations(nir_shader *nir)
    return nir_shader_intrinsics_pass(nir, fix_vertex_input_locations_instr, nir_metadata_all, NULL);
 }
 
+struct trivial_revectorize_state {
+   bool has_xfb;
+   uint32_t component_mask;
+   nir_intrinsic_instr *base;
+   nir_intrinsic_instr *next_emit_vertex;
+   nir_intrinsic_instr *merge[NIR_MAX_VEC_COMPONENTS];
+   struct set *deletions;
+};
+
+/* always skip xfb; scalarized xfb is preferred */
+static bool
+intr_has_xfb(nir_intrinsic_instr *intr)
+{
+   if (!nir_intrinsic_has_io_xfb(intr))
+      return false;
+   for (unsigned i = 0; i < 2; i++) {
+      if (nir_intrinsic_io_xfb(intr).out[i].num_components || nir_intrinsic_io_xfb2(intr).out[i].num_components) {
+         return true;
+      }
+   }
+   return false;
+}
+
+/* helper to avoid vectorizing i/o for different vertices */
+static nir_intrinsic_instr *
+find_next_emit_vertex(nir_intrinsic_instr *intr)
+{
+   bool found = false;
+   nir_foreach_instr_safe(instr, intr->instr.block) {
+      if (instr->type == nir_instr_type_intrinsic) {
+         nir_intrinsic_instr *test_intr = nir_instr_as_intrinsic(instr);
+         if (!found && test_intr != intr)
+            continue;
+         if (!found) {
+            assert(intr == test_intr);
+            found = true;
+            continue;
+         }
+         if (test_intr->intrinsic == nir_intrinsic_emit_vertex)
+            return test_intr;
+      }
+   }
+   return NULL;
+}
+
+/* scan for vectorizable instrs on a given location */
+static bool
+trivial_revectorize_intr_scan(nir_shader *nir, nir_intrinsic_instr *intr, struct trivial_revectorize_state *state)
+{
+   nir_intrinsic_instr *base = state->base;
+
+   if (intr == base)
+      return false;
+
+   if (intr->intrinsic != base->intrinsic)
+      return false;
+
+   if (_mesa_set_search(state->deletions, intr))
+      return false;
+
+   bool is_load = false;
+   bool is_input = false;
+   bool is_interp = false;
+   filter_io_instr(intr, &is_load, &is_input, &is_interp);
+
+   nir_io_semantics base_sem = nir_intrinsic_io_semantics(base);
+   nir_io_semantics test_sem = nir_intrinsic_io_semantics(intr);
+   nir_alu_type base_type = is_load ? nir_intrinsic_dest_type(base) : nir_intrinsic_src_type(base);
+   nir_alu_type test_type = is_load ? nir_intrinsic_dest_type(intr) : nir_intrinsic_src_type(intr);
+   int c = nir_intrinsic_component(intr);
+   /* already detected */
+   if (state->component_mask & BITFIELD_BIT(c))
+      return false;
+   /* not a match */
+   if (base_sem.location != test_sem.location || base_sem.num_slots != test_sem.num_slots || base_type != test_type)
+      return false;
+   /* only vectorize when all srcs match */
+   for (unsigned i = !is_input; i < nir_intrinsic_infos[intr->intrinsic].num_srcs; i++) {
+      if (!nir_srcs_equal(intr->src[i], base->src[i]))
+         return false;
+   }
+   /* never match xfb */
+   state->has_xfb |= intr_has_xfb(intr);
+   if (state->has_xfb)
+      return false;
+   if (nir->info.stage == MESA_SHADER_GEOMETRY) {
+      /* only match same vertex */
+      if (state->next_emit_vertex != find_next_emit_vertex(intr))
+         return false;
+   }
+   uint32_t mask = is_load ? BITFIELD_RANGE(c, intr->num_components) : (nir_intrinsic_write_mask(intr) << c);
+   state->component_mask |= mask;
+   u_foreach_bit(component, mask)
+      state->merge[component] = intr;
+
+   return true;
+}
+
+static bool
+trivial_revectorize_scan(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
+{
+   bool is_load = false;
+   bool is_input = false;
+   bool is_interp = false;
+   if (!filter_io_instr(intr, &is_load, &is_input, &is_interp))
+      return false;
+   if (intr->num_components != 1)
+      return false;
+   nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
+   if (!is_input || b->shader->info.stage != MESA_SHADER_VERTEX) {
+      /* always ignore compact arrays */
+      switch (sem.location) {
+      case VARYING_SLOT_CLIP_DIST0:
+      case VARYING_SLOT_CLIP_DIST1:
+      case VARYING_SLOT_CULL_DIST0:
+      case VARYING_SLOT_CULL_DIST1:
+      case VARYING_SLOT_TESS_LEVEL_INNER:
+      case VARYING_SLOT_TESS_LEVEL_OUTER:
+         return false;
+      default: break;
+      }
+   }
+   /* always ignore to-be-deleted instrs */
+   if (_mesa_set_search(data, intr))
+      return false;
+
+   /* never vectorize xfb */
+   if (intr_has_xfb(intr))
+      return false;
+
+   int ic = nir_intrinsic_component(intr);
+   uint32_t mask = is_load ? BITFIELD_RANGE(ic, intr->num_components) : (nir_intrinsic_write_mask(intr) << ic);
+   /* already vectorized */
+   if (util_bitcount(mask) == 4)
+      return false;
+   struct trivial_revectorize_state state = {
+      .component_mask = mask,
+      .base = intr,
+      /* avoid clobbering i/o for different vertices */
+      .next_emit_vertex = b->shader->info.stage == MESA_SHADER_GEOMETRY ? find_next_emit_vertex(intr) : NULL,
+      .deletions = data,
+   };
+   u_foreach_bit(bit, mask)
+      state.merge[bit] = intr;
+   bool progress = false;
+   nir_foreach_instr(instr, intr->instr.block) {
+      if (instr->type != nir_instr_type_intrinsic)
+         continue;
+      nir_intrinsic_instr *test_intr = nir_instr_as_intrinsic(instr);
+      /* no matching across vertex emission */
+      if (test_intr->intrinsic == nir_intrinsic_emit_vertex)
+         break;
+      progress |= trivial_revectorize_intr_scan(b->shader, test_intr, &state);
+   }
+   if (!progress || state.has_xfb)
+      return false;
+
+   /* verify nothing crazy happened */
+   assert(state.component_mask);
+   for (unsigned i = 0; i < 4; i++) {
+      assert(!state.merge[i] || !intr_has_xfb(state.merge[i]));
+   }
+
+   unsigned first_component = ffs(state.component_mask) - 1;
+   unsigned num_components = util_bitcount(state.component_mask);
+   unsigned num_contiguous = 0;
+   uint32_t contiguous_mask = 0;
+   for (unsigned i = 0; i < num_components; i++) {
+      unsigned c = i + first_component;
+      /* calc mask of contiguous components to vectorize */
+      if (state.component_mask & BITFIELD_BIT(c)) {
+         num_contiguous++;
+         contiguous_mask |= BITFIELD_BIT(c);
+      }
+      /* on the first gap or the last component, vectorize */
+      if (!(state.component_mask & BITFIELD_BIT(c)) || i == num_components - 1) {
+         if (num_contiguous > 1) {
+            /* reindex to enable easy src/dest index comparison */
+            nir_index_ssa_defs(nir_shader_get_entrypoint(b->shader));
+            /* determine the first/last instr to use for the base (vectorized) load/store */
+            unsigned first_c = ffs(contiguous_mask) - 1;
+            nir_intrinsic_instr *base = NULL;
+            unsigned test_idx = is_load ? UINT32_MAX : 0;
+            for (unsigned j = 0; j < num_contiguous; j++) {
+               unsigned merge_c = j + first_c;
+               nir_intrinsic_instr *merge_intr = state.merge[merge_c];
+               /* avoid breaking ssa ordering by using:
+                * - first instr for vectorized load
+                * - last instr for vectorized store
+                * this guarantees all srcs have been seen
+                */
+               if ((is_load && merge_intr->def.index < test_idx) ||
+                   (!is_load && merge_intr->src[0].ssa->index >= test_idx)) {
+                  test_idx = is_load ? merge_intr->def.index : merge_intr->src[0].ssa->index;
+                  base = merge_intr;
+               }
+            }
+            assert(base);
+            /* update instr components */
+            nir_intrinsic_set_component(base, nir_intrinsic_component(state.merge[first_c]));
+            unsigned orig_components = base->num_components;
+            base->num_components = num_contiguous;
+            /* do rewrites after loads and before stores */
+            b->cursor = is_load ? nir_after_instr(&base->instr) : nir_before_instr(&base->instr);
+            if (is_load) {
+               base->def.num_components = num_contiguous;
+               /* iterate the contiguous loaded components and rewrite merged dests */
+               for (unsigned j = 0; j < num_contiguous; j++) {
+                  unsigned merge_c = j + first_c;
+                  nir_intrinsic_instr *merge_intr = state.merge[merge_c];
+                  /* detect if the merged instr loaded multiple components and use swizzle mask for rewrite */
+                  unsigned use_components = merge_intr == base ? orig_components : merge_intr->def.num_components;
+                  nir_def *swiz = nir_channels(b, &base->def, BITFIELD_RANGE(j, use_components));
+                  nir_def_rewrite_uses_after(&merge_intr->def, swiz, merge_intr == base ? swiz->parent_instr : &merge_intr->instr);
+                  j += use_components - 1;
+               }
+            } else {
+               nir_def *comp[NIR_MAX_VEC_COMPONENTS];
+               /* generate swizzled vec of store components and rewrite store src */
+               for (unsigned j = 0; j < num_contiguous; j++) {
+                  unsigned merge_c = j + first_c;
+                  nir_intrinsic_instr *merge_intr = state.merge[merge_c];
+                  /* detect if the merged instr stored multiple components and extract them for rewrite */
+                  unsigned use_components = merge_intr == base ? orig_components : merge_intr->num_components;
+                  for (unsigned k = 0; k < use_components; k++)
+                     comp[j + k] = nir_channel(b, merge_intr->src[0].ssa, k);
+                  j += use_components - 1;
+               }
+               nir_def *val = nir_vec(b, comp, num_contiguous);
+               nir_src_rewrite(&base->src[0], val);
+               nir_intrinsic_set_write_mask(base, BITFIELD_MASK(num_contiguous));
+            }
+            /* deleting instructions during a foreach explodes the compiler, so delete later */
+            for (unsigned j = 0; j < num_contiguous; j++) {
+               unsigned merge_c = j + first_c;
+               nir_intrinsic_instr *merge_intr = state.merge[merge_c];
+               if (merge_intr != base)
+                  _mesa_set_add(data, &merge_intr->instr);
+            }
+         }
+         contiguous_mask = 0;
+         num_contiguous = 0;
+      }
+   }
+
+   return true;
+}
+
+/* attempt to revectorize scalar i/o, ignoring xfb and "hard stuff" */
+static bool
+trivial_revectorize(nir_shader *nir)
+{
+   struct set deletions;
+
+   if (nir->info.stage > MESA_SHADER_FRAGMENT)
+      return false;
+
+   _mesa_set_init(&deletions, NULL, _mesa_hash_pointer, _mesa_key_pointer_equal);
+   bool progress = nir_shader_intrinsics_pass(nir, trivial_revectorize_scan, nir_metadata_dominance, &deletions);
+   /* now it's safe to delete */
+   set_foreach_remove(&deletions, entry) {
+      nir_instr *instr = (void*)entry->key;
+      nir_instr_remove(instr);
+   }
+   ralloc_free(deletions.table);
+   return progress;
+}
+
 struct zink_shader *
 zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
 {
@@ -5855,10 +6111,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
       NIR_PASS_V(nir, nir_lower_alu_vec8_16_srcs);
    }
 
-   nir_variable_mode scalarize = nir_var_shader_in;
-   if (nir->info.stage != MESA_SHADER_FRAGMENT)
-      scalarize |= nir_var_shader_out;
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, scalarize, skip_scalarize, NULL);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_shader_out, NULL, NULL);
    optimize_nir(nir, NULL, true);
    nir_foreach_variable_with_modes(var, nir, nir_var_shader_in | nir_var_shader_out) {
       if (glsl_type_is_image(var->type) || glsl_type_is_sampler(var->type)) {
@@ -5871,6 +6124,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
    NIR_PASS_V(nir, fix_vertex_input_locations);
    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
    scan_nir(screen, nir, ret);
+   NIR_PASS_V(nir, nir_opt_vectorize, NULL, NULL);
+   NIR_PASS_V(nir, trivial_revectorize);
    if (nir->info.io_lowered) {
       rework_io_vars(nir, nir_var_shader_in, ret);
      rework_io_vars(nir, nir_var_shader_out, ret);
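
Reviewer note (illustration only, not part of the patch): the heart of
trivial_revectorize_scan is the walk over the per-location component mask that
groups set bits into contiguous runs before merging them into a single wide
load/store. The standalone C sketch below mirrors just that grouping step;
group_contiguous_components() and the printf are hypothetical stand-ins for the
real merge logic and are not Mesa code.

   #include <stdint.h>
   #include <stdio.h>

   /* walk a 4-bit component mask and report each run of >1 contiguous
    * components, mirroring the contiguous_mask/num_contiguous loop in
    * trivial_revectorize_scan */
   static void
   group_contiguous_components(uint32_t component_mask)
   {
      unsigned first_component = __builtin_ffs(component_mask) - 1;
      unsigned num_components = __builtin_popcount(component_mask);
      unsigned num_contiguous = 0;
      uint32_t contiguous_mask = 0;
      for (unsigned i = 0; i < num_components; i++) {
         unsigned c = i + first_component;
         if (component_mask & (1u << c)) {
            num_contiguous++;
            contiguous_mask |= 1u << c;
         }
         /* on the first gap or the last component, flush the pending run;
          * in the real pass this is where the vectorized load/store is built */
         if (!(component_mask & (1u << c)) || i == num_components - 1) {
            if (num_contiguous > 1)
               printf("vectorize components 0x%x into one %u-wide op\n",
                      contiguous_mask, num_contiguous);
            contiguous_mask = 0;
            num_contiguous = 0;
         }
      }
   }

   int
   main(void)
   {
      /* x, y, w written: only the xy run is merged, w stays scalar */
      group_contiguous_components(0xb);
      return 0;
   }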