diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c index f468d2316e5..e06a8f61392 100644 --- a/src/compiler/nir/nir_loop_analyze.c +++ b/src/compiler/nir/nir_loop_analyze.c @@ -717,11 +717,11 @@ eval_const_binop(nir_op op, unsigned bit_size, } static int -find_replacement(const nir_def **originals, const nir_def *key, +find_replacement(const nir_scalar *originals, nir_scalar key, unsigned num_replacements) { for (int i = 0; i < num_replacements; i++) { - if (originals[i] == key) + if (nir_scalar_equal(originals[i], key)) return i; } @@ -750,12 +750,14 @@ find_replacement(const nir_def **originals, const nir_def *key, * applying the previously described substitution) or false otherwise. */ static bool -try_eval_const_alu(nir_const_value *dest, nir_alu_instr *alu, - const nir_def **originals, - const nir_const_value **replacements, +try_eval_const_alu(nir_const_value *dest, nir_scalar alu_s, const nir_scalar *originals, + const nir_const_value *replacements, unsigned num_replacements, unsigned execution_mode) { - nir_const_value src[NIR_MAX_VEC_COMPONENTS][NIR_MAX_VEC_COMPONENTS]; + nir_alu_instr *alu = nir_instr_as_alu(alu_s.def->parent_instr); + + if (nir_op_infos[alu->op].output_size) + return false; /* In the case that any outputs/inputs have unsized types, then we need to * guess the bit-size. In this case, the validator ensures that all @@ -767,55 +769,42 @@ try_eval_const_alu(nir_const_value *dest, nir_alu_instr *alu, * (although it still requires to receive a valid bit-size). */ unsigned bit_size = 0; - if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].output_type)) + if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].output_type)) { bit_size = alu->def.bit_size; + } else { + for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { + if (!nir_alu_type_get_type_size(nir_op_infos[alu->op].input_types[i])) + bit_size = alu->src[i].src.ssa->bit_size; + } + + if (bit_size == 0) + bit_size = 32; + } + + nir_const_value src[NIR_MAX_VEC_COMPONENTS]; + nir_const_value *src_ptrs[NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { - if (bit_size == 0 && - !nir_alu_type_get_type_size(nir_op_infos[alu->op].input_types[i])) - bit_size = alu->src[i].src.ssa->bit_size; + nir_scalar src_s = nir_scalar_chase_alu_src(alu_s, i); - nir_instr *src_instr = alu->src[i].src.ssa->parent_instr; + src_ptrs[i] = &src[i]; + if (nir_scalar_is_const(src_s)) { + src[i] = nir_scalar_as_const_value(src_s); + continue; + } - if (src_instr->type == nir_instr_type_load_const) { - nir_load_const_instr *load_const = nir_instr_as_load_const(src_instr); - - for (unsigned j = 0; j < nir_ssa_alu_instr_src_components(alu, i); - j++) { - src[i][j] = load_const->value[alu->src[i].swizzle[j]]; - } - } else { - int r = find_replacement(originals, alu->src[i].src.ssa, - num_replacements); - - if (r >= 0) { - for (unsigned j = 0; j < nir_ssa_alu_instr_src_components(alu, i); - j++) { - src[i][j] = replacements[r][alu->src[i].swizzle[j]]; - } - } else if (src_instr->type == nir_instr_type_alu) { - memset(src[i], 0, sizeof(src[i])); - - if (!try_eval_const_alu(src[i], nir_instr_as_alu(src_instr), - originals, replacements, num_replacements, - execution_mode)) - return false; - } else { - return false; - } + int r = find_replacement(originals, src_s, num_replacements); + if (r >= 0) { + src[i] = replacements[r]; + } else if (!nir_scalar_is_alu(src_s) || + !try_eval_const_alu(&src[i], src_s, + originals, replacements, + num_replacements, execution_mode)) { + return false; } } - if (bit_size == 0) - bit_size = 32; - - nir_const_value *srcs[NIR_MAX_VEC_COMPONENTS]; - - for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) - srcs[i] = src[i]; - - nir_eval_const_opcode(alu->op, dest, alu->def.num_components, - bit_size, srcs, execution_mode); + nir_eval_const_opcode(alu->op, dest, 1, bit_size, src_ptrs, execution_mode); return true; } @@ -931,13 +920,14 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu, nir_const_value result; nir_const_value iter = initial; - const nir_def *originals[2] = { basis, NULL }; - const nir_const_value *replacements[2] = { &iter, NULL }; + const nir_scalar original = nir_get_scalar(basis, 0); + const nir_scalar cond = nir_get_scalar(&cond_alu->def, 0); + const nir_scalar incr = nir_get_scalar(&incr_alu->def, 0); while (iter_count <= max_unroll_iterations) { bool success; - success = try_eval_const_alu(&result, cond_alu, originals, replacements, + success = try_eval_const_alu(&result, cond, &original, &iter, 1, execution_mode); if (!success) return -1; @@ -948,7 +938,7 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu, iter_count++; - success = try_eval_const_alu(&result, incr_alu, originals, replacements, + success = try_eval_const_alu(&result, incr, &original, &iter, 1, execution_mode); assert(success); @@ -966,10 +956,11 @@ will_break_on_first_iteration(nir_alu_instr *cond_alu, nir_def *basis, { nir_const_value result; - const nir_def *originals[2] = { basis, limit_basis }; - const nir_const_value *replacements[2] = { &initial, &limit }; + const nir_scalar originals[2] = { nir_get_scalar(basis, 0), nir_get_scalar(limit_basis, 0) }; + const nir_const_value replacements[2] = { initial, limit }; - ASSERTED bool success = try_eval_const_alu(&result, cond_alu, originals, + const nir_scalar cond = nir_get_scalar(&cond_alu->def, 0); + ASSERTED bool success = try_eval_const_alu(&result, cond, originals, replacements, 2, execution_mode); assert(success);