diff --git a/src/compiler/nir/nir_deref.c b/src/compiler/nir/nir_deref.c index cb6847ac9b8..90ad90d6763 100644 --- a/src/compiler/nir/nir_deref.c +++ b/src/compiler/nir/nir_deref.c @@ -851,6 +851,30 @@ nir_deref_instr_fixup_child_types(nir_deref_instr *parent) } } +static bool +opt_alu_of_cast(nir_alu_instr *alu) +{ + bool progress = false; + + for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { + assert(alu->src[i].src.is_ssa); + nir_instr *src_instr = alu->src[i].src.ssa->parent_instr; + if (src_instr->type != nir_instr_type_deref) + continue; + + nir_deref_instr *src_deref = nir_instr_as_deref(src_instr); + if (src_deref->deref_type != nir_deref_type_cast) + continue; + + assert(src_deref->parent.is_ssa); + nir_instr_rewrite_src_ssa(&alu->instr, &alu->src[i].src, + src_deref->parent.ssa); + progress = true; + } + + return progress; +} + static bool is_trivial_array_deref_cast(nir_deref_instr *cast) { @@ -1347,6 +1371,13 @@ nir_opt_deref_impl(nir_function_impl *impl) b.cursor = nir_before_instr(instr); switch (instr->type) { + case nir_instr_type_alu: { + nir_alu_instr *alu = nir_instr_as_alu(instr); + if (opt_alu_of_cast(alu)) + progress = true; + break; + } + case nir_instr_type_deref: { nir_deref_instr *deref = nir_instr_as_deref(instr);