diff --git a/src/compiler/nir/nir_opt_intrinsics.c b/src/compiler/nir/nir_opt_intrinsics.c index a277aaf5a29..e1927c6ae45 100644 --- a/src/compiler/nir/nir_opt_intrinsics.c +++ b/src/compiler/nir/nir_opt_intrinsics.c @@ -260,18 +260,20 @@ opt_intrinsics_alu(nir_builder *b, nir_alu_instr *alu, } static bool -try_opt_exclusive_scan_to_inclusive(nir_intrinsic_instr *intrin) +try_opt_exclusive_scan_to_inclusive(nir_builder *b, nir_intrinsic_instr *intrin) { if (intrin->def.num_components != 1) return false; + nir_op reduction_op = nir_intrinsic_reduction_op(intrin); + nir_foreach_use_including_if(src, &intrin->def) { if (nir_src_is_if(src) || nir_src_parent_instr(src)->type != nir_instr_type_alu) return false; nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(src)); - if (alu->op != (nir_op)nir_intrinsic_reduction_op(intrin)) + if (alu->op != reduction_op) return false; /* Don't reassociate exact float operations. */ @@ -309,14 +311,16 @@ try_opt_exclusive_scan_to_inclusive(nir_intrinsic_instr *intrin) } /* Convert to inclusive scan. */ - intrin->intrinsic = nir_intrinsic_inclusive_scan; + nir_def *incl_scan = nir_inclusive_scan(b, intrin->src[0].ssa, .reduction_op = reduction_op); nir_foreach_use_including_if_safe(src, &intrin->def) { /* Remove alu. */ nir_alu_instr *alu = nir_instr_as_alu(nir_src_parent_instr(src)); - nir_def_replace(&alu->def, &intrin->def); + nir_def_replace(&alu->def, incl_scan); } + nir_instr_remove(&intrin->instr); + return true; } @@ -374,7 +378,7 @@ opt_intrinsics_intrin(nir_builder *b, nir_intrinsic_instr *intrin, return progress; } case nir_intrinsic_exclusive_scan: - return try_opt_exclusive_scan_to_inclusive(intrin); + return try_opt_exclusive_scan_to_inclusive(b, intrin); default: return false; }