/*
 * Patch summary (documentation only; the patch text below is unchanged).
 *
 * Target file: src/compiler/nir/nir_opt_idiv_const.c
 *
 * Adds two static helpers next to build_idiv() and routes two more opcodes
 * through the constant-divisor lowering:
 *   - nir_opt_idiv_const_instr(): the per-component switch gains cases for
 *     nir_op_imod -> build_imod() and nir_op_irem -> build_irem(), storing
 *     the result in q[comp] like the existing nir_op_umod case.
 *   - nir_opt_idiv_const_impl(): the opcode filter no longer skips
 *     nir_op_imod and nir_op_irem instructions.
 *
 * In both helpers, b is the builder used to emit instructions, n is the
 * dividend SSA value and d is the constant divisor. int_min is
 * u_intN_min(n->bit_size), presumably the most negative signed value at
 * n's bit size (helper not visible here; confirm). Builder helper semantics
 * (nir_bcsel, nir_ilt, nir_ult, ...) are assumed from their names.
 *
 * build_irem(b, n, d):
 *   - d == 0: returns an immediate 0 of n's bit size.
 *   - d == int_min: returns 0 when n == int_min, otherwise n unchanged.
 *     This case is tested before d is negated below.
 *   - otherwise d is replaced by its absolute value, so the sign of the
 *     divisor does not influence the result:
 *       - power-of-two d: tmp is n + (d - 1) when n < 0, else n; the result
 *         is n - (tmp & -d). The bias on negative n makes the masked value
 *         the multiple of d obtained by rounding n toward zero.
 *       - any other d: n - build_idiv(b, n, d) * d.
 *         NOTE(review): yields a remainder with the sign of n only if
 *         build_idiv() truncates toward zero; its body is outside this
 *         patch, so confirm.
 *
 * build_imod(b, n, d):
 *   - d == 0: returns an immediate 0 of n's bit size.
 *   - d == int_min: returns n when int_min < n in the unsigned comparison
 *     (n has the sign bit set but is not int_min) or when n == 0;
 *     otherwise returns int_min + n.
 *   - d > 0 and a power of two: n & (d - 1).
 *   - d < 0 and -d a power of two: res = n | d; returns 0 when res == d,
 *     otherwise res.
 *   - any other d: rem = build_irem(b, n, d). rem is returned as-is when it
 *     is zero or when n is on the same side of zero as d (n < 0 for d < 0,
 *     n >= 0 for d > 0); otherwise rem + d is returned so the result takes
 *     the sign of the divisor.
 *
 * NOTE(review): both helpers compare the int64_t d against int_min directly,
 * which presumes the caller passes d sign-extended from n's bit size; the
 * caller's extraction of d is not visible in these hunks, so verify.
 */
diff --git a/src/compiler/nir/nir_opt_idiv_const.c b/src/compiler/nir/nir_opt_idiv_const.c index f94efc55991..20e3d552575 100644 --- a/src/compiler/nir/nir_opt_idiv_const.c +++ b/src/compiler/nir/nir_opt_idiv_const.c @@ -100,6 +100,53 @@ build_idiv(nir_builder *b, nir_ssa_def *n, int64_t d) } } +static nir_ssa_def * +build_irem(nir_builder *b, nir_ssa_def *n, int64_t d) +{ + int64_t int_min = u_intN_min(n->bit_size); + if (d == 0) { + return nir_imm_intN_t(b, 0, n->bit_size); + } else if (d == int_min) { + return nir_bcsel(b, nir_ieq_imm(b, n, int_min), nir_imm_intN_t(b, 0, n->bit_size), n); + } else { + d = d < 0 ? -d : d; + if (util_is_power_of_two_or_zero64(d)) { + nir_ssa_def *tmp = nir_bcsel(b, nir_ilt(b, n, nir_imm_intN_t(b, 0, n->bit_size)), + nir_iadd_imm(b, n, d - 1), n); + return nir_isub(b, n, nir_iand_imm(b, tmp, -d)); + } else { + return nir_isub(b, n, nir_imul(b, build_idiv(b, n, d), + nir_imm_intN_t(b, d, n->bit_size))); + } + } +} + +static nir_ssa_def * +build_imod(nir_builder *b, nir_ssa_def *n, int64_t d) +{ + int64_t int_min = u_intN_min(n->bit_size); + if (d == 0) { + return nir_imm_intN_t(b, 0, n->bit_size); + } else if (d == int_min) { + nir_ssa_def *int_min_def = nir_imm_intN_t(b, int_min, n->bit_size); + nir_ssa_def *is_neg_not_int_min = nir_ult(b, int_min_def, n); + nir_ssa_def *is_zero = nir_ieq_imm(b, n, 0); + return nir_bcsel(b, nir_ior(b, is_neg_not_int_min, is_zero), n, nir_iadd(b, int_min_def, n)); + } else if (d > 0 && util_is_power_of_two_or_zero64(d)) { + return nir_iand(b, n, nir_imm_intN_t(b, d - 1, n->bit_size)); + } else if (d < 0 && util_is_power_of_two_or_zero64(-d)) { + nir_ssa_def *d_def = nir_imm_intN_t(b, d, n->bit_size); + nir_ssa_def *res = nir_ior(b, n, d_def); + return nir_bcsel(b, nir_ieq(b, res, d_def), nir_imm_intN_t(b, 0, n->bit_size), res); + } else { + nir_ssa_def *rem = build_irem(b, n, d); + nir_ssa_def *zero = nir_imm_intN_t(b, 0, n->bit_size); + nir_ssa_def *sign_same = d < 0 ? 
nir_ilt(b, n, zero) : nir_ige(b, n, zero); + nir_ssa_def *rem_zero = nir_ieq(b, rem, zero); + return nir_bcsel(b, nir_ior(b, rem_zero, sign_same), rem, nir_iadd_imm(b, rem, d)); + } +} + static bool nir_opt_idiv_const_instr(nir_builder *b, nir_alu_instr *alu) { @@ -143,6 +190,12 @@ nir_opt_idiv_const_instr(nir_builder *b, nir_alu_instr *alu) case nir_op_umod: q[comp] = build_umod(b, n, d); break; + case nir_op_imod: + q[comp] = build_imod(b, n, d); + break; + case nir_op_irem: + q[comp] = build_irem(b, n, d); + break; default: unreachable("Unknown integer division op"); } @@ -171,7 +224,9 @@ nir_opt_idiv_const_impl(nir_function_impl *impl, unsigned min_bit_size) nir_alu_instr *alu = nir_instr_as_alu(instr); if (alu->op != nir_op_udiv && alu->op != nir_op_idiv && - alu->op != nir_op_umod) + alu->op != nir_op_umod && + alu->op != nir_op_imod && + alu->op != nir_op_irem) continue; assert(alu->dest.dest.is_ssa);