nir: allow loops with unknown induction var initialiser to unroll

If the condition of the loop terminator is based on an unsigned value we
can in some cases find the max number of possible loop trips. With the
max loop trips know a complex unroll can unroll the loop.

For example:

   uniform uint x;
   uint i = x;
   while (true) {
      if (i >= 4)
         break;

      i += 6;
   }

The above loop can be unrolled even though we don't know the initial
value of the induction variable because it can have at most 1 iteration.

There were no changes with my shader-db collection. Change was inspired
by MR #31312 where builtin shader code failed to unroll.

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31701>
This commit is contained in:
Timothy Arceri 2024-10-11 15:20:04 +11:00
parent fcaf0f2590
commit 6ca81adffc
3 changed files with 158 additions and 16 deletions

View file

@ -1135,17 +1135,11 @@ get_induction_and_limit_vars(nir_scalar cond,
nir_loop_variable *src1_lv = get_loop_var(rhs.def, state);
if (src0_lv->type == basic_induction) {
if (!nir_src_is_const(*src0_lv->init_src))
return false;
*ind = lhs;
*limit = rhs;
*limit_rhs = true;
return true;
} else if (src1_lv->type == basic_induction) {
if (!nir_src_is_const(*src1_lv->init_src))
return false;
*ind = rhs;
*limit = lhs;
*limit_rhs = false;
@ -1331,20 +1325,56 @@ find_trip_count(loop_info_state *state, unsigned execution_mode,
lv->update_src->swizzle[basic_ind.comp]
};
nir_alu_instr *step_alu =
nir_instr_as_alu(nir_src_parent_instr(&lv->update_src->src));
/* If the comparision is of unsigned type we don't necessarily need to
* know the initial value to be able to calculate the max number of
* iterations
*/
bool can_find_max_trip_count = step_alu->op == nir_op_iadd &&
((alu_op == nir_op_uge && !invert_cond && limit_rhs) ||
(alu_op == nir_op_ult && !invert_cond && !limit_rhs));
/* nir_op_isub should have been lowered away by this point */
assert(step_alu->op != nir_op_isub);
/* For nir_op_uge as alu_op, the induction variable is [0,limit). For
* nir_op_ult, it's [0,limit]. It must always be step_val larger in the
* next iteration to use the can_find_max_trip_count=true path. This
* check ensures that no unsigned overflow happens.
* TODO: support for overflow could be added if a non-zero initial_val
* is chosen.
*/
if (can_find_max_trip_count && nir_scalar_is_const(alu_s)) {
uint64_t uint_max = u_uintN_max(alu_s.def->bit_size);
uint64_t max_step_val =
uint_max - nir_const_value_as_uint(limit_val, alu_s.def->bit_size) +
(alu_op == nir_op_uge ? 1 : 0);
can_find_max_trip_count &= nir_scalar_as_uint(alu_s) <= max_step_val;
}
/* We are not guaranteed by that at one of these sources is a constant.
* Try to find one.
*/
if (!nir_scalar_is_const(initial_s) ||
if ((!nir_scalar_is_const(initial_s) && !can_find_max_trip_count) ||
!nir_scalar_is_const(alu_s))
continue;
nir_const_value initial_val = nir_scalar_as_const_value(initial_s);
nir_const_value initial_val;
if (nir_scalar_is_const(initial_s))
initial_val = nir_scalar_as_const_value(initial_s);
else {
trip_count_known = false;
terminator->exact_trip_count_unknown = true;
initial_val = nir_const_value_for_uint(0, 32);
assert(can_find_max_trip_count);
}
nir_const_value step_val = nir_scalar_as_const_value(alu_s);
int iterations = calculate_iterations(nir_get_scalar(lv->basis, basic_ind.comp), limit,
initial_val, step_val, limit_val,
nir_instr_as_alu(nir_src_parent_instr(&lv->update_src->src)),
cond,
step_alu, cond,
alu_op, limit_rhs,
invert_cond,
execution_mode,

View file

@ -60,6 +60,8 @@ struct loop_builder_param {
nir_def *(*incr_instr)(nir_builder *,
nir_def *,
nir_def *);
bool use_unknown_init_value;
bool invert_exit_condition_and_continue_branch;
};
static nir_loop *
@ -75,7 +77,14 @@ loop_builder(nir_builder *b, loop_builder_param p)
* i = incr_instr(i, incr_value);
* }
*/
nir_def *ssa_0 = nir_imm_int(b, p.init_value);
nir_def *ssa_0;
if (p.use_unknown_init_value) {
nir_def *one = nir_imm_int(b, 1);
nir_def *twelve = nir_imm_int(b, 12);
ssa_0 = nir_load_ubo(b, 1, 32, one, twelve, (gl_access_qualifier)0, 0, 0, 0, 16);
} else
ssa_0 = nir_imm_int(b, p.init_value);
nir_def *ssa_1 = nir_imm_int(b, p.cond_value);
nir_def *ssa_2 = nir_imm_int(b, p.incr_value);
@ -91,8 +100,14 @@ loop_builder(nir_builder *b, loop_builder_param p)
nir_def *ssa_5 = &phi->def;
nir_def *ssa_3 = p.cond_instr(b, ssa_5, ssa_1);
if (p.invert_exit_condition_and_continue_branch)
ssa_3 = nir_inot(b, ssa_3);
nir_if *nif = nir_push_if(b, ssa_3);
{
if (p.invert_exit_condition_and_continue_branch)
nir_push_else(b, NULL);
nir_jump_instr *jump = nir_jump_instr_create(b->shader, nir_jump_break);
nir_builder_instr_insert(b, &jump->instr);
}
@ -199,7 +214,9 @@ TEST_F(nir_loop_analyze_test, one_iteration_fneu)
nir_loop *loop =
loop_builder(&b, {.init_value = 0xe7000000, .cond_value = 0xe7000000,
.incr_value = 0x5b000000,
.cond_instr = nir_fneu, .incr_instr = nir_fadd});
.cond_instr = nir_fneu, .incr_instr = nir_fadd,
.use_unknown_init_value = false,
.invert_exit_condition_and_continue_branch = false});
/* At this point, we should have:
*
@ -319,7 +336,9 @@ INOT_COMPARE(ilt_imin_rev)
.cond_value = _cond_value, \
.incr_value = _incr_value, \
.cond_instr = nir_ ## cond, \
.incr_instr = nir_ ## incr}); \
.incr_instr = nir_ ## incr, \
.use_unknown_init_value = false, \
.invert_exit_condition_and_continue_branch = false}); \
\
nir_validate_shader(b.shader, "input"); \
\
@ -345,6 +364,42 @@ INOT_COMPARE(ilt_imin_rev)
} \
}
#define INEXACT_COUNT_TEST_UNKNOWN_INIT(_cond_value, _incr_value, cond, incr, count, invert) \
TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _inexact_count_ ## count ## _invert_ ## invert) \
{ \
nir_loop *loop = \
loop_builder(&b, {.init_value = 0, \
.cond_value = _cond_value, \
.incr_value = _incr_value, \
.cond_instr = nir_ ## cond, \
.incr_instr = nir_ ## incr, \
.use_unknown_init_value = true, \
.invert_exit_condition_and_continue_branch = invert }); \
\
nir_validate_shader(b.shader, "input"); \
\
nir_loop_analyze_impl(b.impl, nir_var_all, false); \
\
ASSERT_NE((void *)0, loop->info); \
EXPECT_NE((void *)0, loop->info->limiting_terminator); \
EXPECT_EQ(count, loop->info->max_trip_count); \
EXPECT_FALSE(loop->info->exact_trip_count_known); \
\
EXPECT_EQ(2, loop->info->num_induction_vars); \
ASSERT_NE((void *)0, loop->info->induction_vars); \
\
const nir_loop_induction_variable *const ivars = \
loop->info->induction_vars; \
\
for (unsigned i = 0; i < loop->info->num_induction_vars; i++) { \
EXPECT_NE((void *)0, ivars[i].def); \
ASSERT_NE((void *)0, ivars[i].init_src); \
EXPECT_FALSE(nir_src_is_const(*ivars[i].init_src)); \
ASSERT_NE((void *)0, ivars[i].update_src); \
EXPECT_TRUE(nir_src_is_const(ivars[i].update_src->src)); \
} \
}
#define INEXACT_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \
TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _inexact_count_ ## count) \
{ \
@ -353,7 +408,9 @@ INOT_COMPARE(ilt_imin_rev)
.cond_value = _cond_value, \
.incr_value = _incr_value, \
.cond_instr = nir_ ## cond, \
.incr_instr = nir_ ## incr}); \
.incr_instr = nir_ ## incr, \
.use_unknown_init_value = false, \
.invert_exit_condition_and_continue_branch = false}); \
\
nir_validate_shader(b.shader, "input"); \
\
@ -387,7 +444,9 @@ INOT_COMPARE(ilt_imin_rev)
.cond_value = _cond_value, \
.incr_value = _incr_value, \
.cond_instr = nir_ ## cond, \
.incr_instr = nir_ ## incr}); \
.incr_instr = nir_ ## incr, \
.use_unknown_init_value = false, \
.invert_exit_condition_and_continue_branch = false}); \
\
nir_validate_shader(b.shader, "input"); \
\
@ -407,7 +466,9 @@ INOT_COMPARE(ilt_imin_rev)
.cond_value = _cond_value, \
.incr_value = _incr_value, \
.cond_instr = nir_ ## cond, \
.incr_instr = nir_ ## incr}); \
.incr_instr = nir_ ## incr, \
.use_unknown_init_value = false, \
.invert_exit_condition_and_continue_branch = false}); \
\
nir_validate_shader(b.shader, "input"); \
\
@ -1636,3 +1697,27 @@ INEXACT_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt_imax, iadd, 5)
* }
*/
INEXACT_COUNT_TEST(0x00000001, 0x00000100, 0x00000001, uge_umin, ishl, 8)
/* uniform uint x;
* uint i = x;
* while (true) {
* if (i >= 4)
* break;
*
* i += 6;
* }
*/
INEXACT_COUNT_TEST_UNKNOWN_INIT(0x00000004, 0x00000006, uge, iadd, 1, 0)
/* uniform uint x;
* uint i = x;
* while (true) {
* if (!(i >= 4))
* continue;
* else
* break;
*
* i += 6;
* }
*/
INEXACT_COUNT_TEST_UNKNOWN_INIT(0x00000004, 0x00000006, uge, iadd, 1, 1)

View file

@ -41,6 +41,24 @@
EXPECT_EQ(_exp_loop_count, count_loops()); \
}
#define UNROLL_TEST_UNKNOWN_INIT_INSERT(_label, __type, _limit, _step, \
_cond, _incr, _rev, _exp_res, \
_exp_instr_count, _exp_loop_count)\
TEST_F(nir_loop_unroll_test, _label) \
{ \
nir_def *one = nir_imm_int(&bld, 1); \
nir_def *twelve = nir_imm_int(&bld, 12); \
nir_def *init = nir_load_ubo(&bld, 1, 32, one, twelve, (gl_access_qualifier)0, 0, 0, 0, 16); \
nir_def *limit = nir_imm_##__type(&bld, _limit); \
nir_def *step = nir_imm_##__type(&bld, _step); \
loop_unroll_test_helper(&bld, init, limit, step, \
&nir_##_cond, &nir_##_incr, _rev); \
EXPECT_##_exp_res(nir_opt_loop_unroll(bld.shader)); \
EXPECT_EQ(_exp_instr_count, count_instr(nir_op_##_incr)); \
EXPECT_EQ(_exp_loop_count, count_loops()); \
}
namespace {
class nir_loop_unroll_test : public ::testing::Test {
@ -175,3 +193,12 @@ UNROLL_TEST_INSERT(lshl_neg, int, 0xf0f0f0f0, 0, 1,
ige, ishl, false, TRUE, 4, 0)
UNROLL_TEST_INSERT(lshl_neg_rev, int, 0xf0f0f0f0, 0, 1,
ilt, ishl, true, TRUE, 4, 0)
UNROLL_TEST_UNKNOWN_INIT_INSERT(iadd_uge_unknown_init_gt, int, 4, 6,
uge, iadd, false, TRUE, 2, 0)
UNROLL_TEST_UNKNOWN_INIT_INSERT(iadd_uge_unknown_init_eq, int, 16, 4,
uge, iadd, false, TRUE, 5, 0)
UNROLL_TEST_UNKNOWN_INIT_INSERT(iadd_ugt_unknown_init_eq, int, 16, 4,
ult, iadd, true, TRUE, 6, 0)
UNROLL_TEST_UNKNOWN_INIT_INSERT(iadd_ige_unknown_init, int, 4, 6,
ige, iadd, false, FALSE, 1, 1)