Patch for src/compiler/nir/nir_opt_shrink_vectors.c.

What the hunks below do:

 * reswizzle_alu_uses() (new): walks every use of an SSA def, asserts that
   the using instruction is an ALU instruction, and remaps all
   NIR_MAX_VEC_COMPONENTS swizzle entries of that ALU source through the
   caller-supplied reswizzle table.
 * is_only_used_by_alu() (new): returns false as soon as one use belongs to
   a non-ALU instruction, true otherwise.
 * opt_shrink_vectors_alu(): now bails out when the def has any non-ALU use
   (previously only intrinsic uses were rejected) and calls
   reswizzle_alu_uses() instead of its open-coded loop.
 * opt_shrink_vectors_load_const(): no longer forwards to
   shrink_dest_to_read_mask().  It returns false for single-component defs,
   defs with a non-ALU use, and defs of which no component is read.
   Otherwise it visits the read components in order, maps a component onto
   an already kept slot when both value[].u64 fields are equal, and moves it
   to the next free slot of instr->value when they are not.  If that leaves
   fewer components, def->num_components is lowered and the ALU uses are
   reswizzled.

NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch.  The hunk line counts match the @@
headers, but the indentation of context lines (for example the "break;"
line) is a best guess; if a hunk does not match the tree, apply with
"git apply --ignore-whitespace".

diff --git a/src/compiler/nir/nir_opt_shrink_vectors.c b/src/compiler/nir/nir_opt_shrink_vectors.c
index 9cb186f6c25..3a77a200f77 100644
--- a/src/compiler/nir/nir_opt_shrink_vectors.c
+++ b/src/compiler/nir/nir_opt_shrink_vectors.c
@@ -70,6 +70,31 @@ shrink_dest_to_read_mask(nir_ssa_def *def)
    return false;
 }
 
+static void
+reswizzle_alu_uses(nir_ssa_def *def, uint8_t *reswizzle)
+{
+   nir_foreach_use(use_src, def) {
+      /* all uses must be ALU instructions */
+      assert(use_src->parent_instr->type == nir_instr_type_alu);
+      nir_alu_src *alu_src = (nir_alu_src*)use_src;
+
+      /* reswizzle ALU sources */
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
+   }
+}
+
+static bool
+is_only_used_by_alu(nir_ssa_def *def)
+{
+   nir_foreach_use(use_src, def) {
+      if (use_src->parent_instr->type != nir_instr_type_alu)
+         return false;
+   }
+
+   return true;
+}
+
 static bool
 opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
 {
@@ -93,11 +118,9 @@ opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
          break;
    }
 
-   /* don't remove any channels if used by an intrinsic */
-   nir_foreach_use(use_src, def) {
-      if (use_src->parent_instr->type == nir_instr_type_intrinsic)
-         return false;
-   }
+   /* don't remove any channels if used by non-ALU */
+   if (!is_only_used_by_alu(def))
+      return false;
 
    unsigned mask = nir_ssa_def_components_read(def);
    unsigned last_bit = util_last_bit(mask);
@@ -156,12 +179,7 @@ opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
    assert(index == num_components);
 
    /* update uses */
-   nir_foreach_use(use_src, def) {
-      assert(use_src->parent_instr->type == nir_instr_type_alu);
-      nir_alu_src *alu_src = (nir_alu_src*)use_src;
-      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
-         alu_src->swizzle[i] = reswizzle[alu_src->swizzle[i]];
-   }
+   reswizzle_alu_uses(def, reswizzle);
 
    return true;
 }
@@ -204,7 +222,51 @@ opt_shrink_vectors_intrinsic(nir_builder *b, nir_intrinsic_instr *instr)
 static bool
 opt_shrink_vectors_load_const(nir_load_const_instr *instr)
 {
-   return shrink_dest_to_read_mask(&instr->def);
+   nir_ssa_def *def = &instr->def;
+
+   /* early out if there's nothing to do. */
+   if (def->num_components == 1)
+      return false;
+
+   /* don't remove any channels if used by non-ALU */
+   if (!is_only_used_by_alu(def))
+      return false;
+
+   unsigned mask = nir_ssa_def_components_read(def);
+
+   /* If nothing was read, leave it up to DCE. */
+   if (!mask)
+      return false;
+
+   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
+   unsigned num_components = 0;
+   for (unsigned i = 0; i < def->num_components; i++) {
+      if (!((mask >> i) & 0x1))
+         continue;
+
+      /* Try reuse a component with the same constant */
+      unsigned j;
+      for (j = 0; j < num_components; j++) {
+         if (instr->value[i].u64 == instr->value[j].u64) {
+            reswizzle[i] = j;
+            break;
+         }
+      }
+
+      /* Otherwise, just append the value */
+      if (j == num_components) {
+         instr->value[num_components] = instr->value[i];
+         reswizzle[i] = num_components++;
+      }
+   }
+
+   if (num_components == def->num_components)
+      return false;
+
+   def->num_components = num_components;
+   reswizzle_alu_uses(def, reswizzle);
+
+   return true;
 }
 
 static bool