diff --git a/src/compiler/nir/nir_opt_intrinsics.c b/src/compiler/nir/nir_opt_intrinsics.c
index fe2a463b03b..1a4e5855a06 100644
--- a/src/compiler/nir/nir_opt_intrinsics.c
+++ b/src/compiler/nir/nir_opt_intrinsics.c
@@ -240,6 +240,87 @@ opt_intrinsics_alu(nir_builder *b, nir_alu_instr *alu,
    }
 }
 
+/**
+ * Try to turn an exclusive scan into an inclusive scan.
+ *
+ * This applies when every use of the scan result is a scalar ALU instruction
+ * that combines the scan result with the scan's own source using the scan's
+ * reduction op, i.e. op(exclusive_scan(x), x).  The intrinsic is then turned
+ * into an inclusive scan in place and those ALU instructions are removed.
+ *
+ * Returns true if the intrinsic was converted, false if it was left untouched.
+ */
+static bool
+try_opt_exclusive_scan_to_inclusive(nir_intrinsic_instr *intrin)
+{
+   if (intrin->dest.ssa.num_components != 1)
+      return false;
+
+   /* The scan source is the same for every use, so resolve it only once. */
+   nir_ssa_scalar scan_scalar = nir_ssa_scalar_resolved(intrin->src[0].ssa, 0);
+
+   nir_foreach_use_including_if(src, &intrin->dest.ssa) {
+      if (src->is_if || src->parent_instr->type != nir_instr_type_alu)
+         return false;
+
+      nir_alu_instr *alu = nir_instr_as_alu(src->parent_instr);
+
+      if (alu->op != (nir_op)nir_intrinsic_reduction_op(intrin))
+         return false;
+
+      /* Don't reassociate exact float operations. */
+      if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
+          alu->op != nir_op_fmax && alu->op != nir_op_fmin && alu->exact)
+         return false;
+
+      if (alu->dest.dest.ssa.num_components != 1)
+         return false;
+
+      nir_alu_src *alu_src = list_entry(src, nir_alu_src, src);
+      unsigned src_index = alu_src - alu->src;
+
+      assert(src_index < 2 && nir_op_infos[alu->op].num_inputs == 2);
+
+      /* The other operand has to be the value that is being scanned. */
+      nir_ssa_scalar op_scalar = nir_ssa_scalar_resolved(alu->src[!src_index].src.ssa,
+                                                         alu->src[!src_index].swizzle[0]);
+
+      if (scan_scalar.def != op_scalar.def || scan_scalar.comp != op_scalar.comp)
+         return false;
+   }
+
+   /* Convert to inclusive scan. */
+   intrin->intrinsic = nir_intrinsic_inclusive_scan;
+
+   /* Rewriting the uses of an ALU result to the scan result moves them onto
+    * the scan's use list.  Doing that while iterating that same list would
+    * let the loop reach the consumers of an already folded ALU instruction
+    * and remove them by mistake.  So fold every other ALU instruction into
+    * the first one inside the loop, and only redirect the first one to the
+    * scan result once the iteration is over.
+    */
+   nir_alu_instr *first_alu = NULL;
+
+   nir_foreach_use_including_if_safe(src, &intrin->dest.ssa) {
+      nir_alu_instr *alu = nir_instr_as_alu(src->parent_instr);
+
+      if (!first_alu) {
+         first_alu = alu;
+         continue;
+      }
+
+      nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, &first_alu->dest.dest.ssa);
+      nir_instr_remove(&alu->instr);
+   }
+
+   if (first_alu) {
+      nir_ssa_def_rewrite_uses(&first_alu->dest.dest.ssa, &intrin->dest.ssa);
+      nir_instr_remove(&first_alu->instr);
+   }
+
+   return true;
+}
+
 static bool
 opt_intrinsics_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
                       const struct nir_shader_compiler_options *options)
@@ -282,7 +363,8 @@ opt_intrinsics_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
       }
       return progress;
    }
-
+   case nir_intrinsic_exclusive_scan:
+      return try_opt_exclusive_scan_to_inclusive(intrin);
    default:
       return false;
    }