nir/loop_analyze: insert only induction vars into hash map

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33131>
This commit is contained in:
Daniel Schürmann 2025-01-17 12:50:09 +01:00 committed by Marge Bot
parent f0fd04327f
commit 7eb2e96d16

View file

@ -27,18 +27,10 @@
#include "nir.h"
#include "nir_constant_expressions.h"
typedef enum {
undefined,
basic_induction
} nir_loop_variable_type;
typedef struct nir_loop_variable {
/* The ssa_def associated with this info */
nir_def *def;
/* The type of this ssa_def */
nir_loop_variable_type type;
/* Could be a basic_induction if following uniforms are inlined */
nir_src *init_src;
nir_alu_src *update_src;
@ -72,11 +64,8 @@ get_loop_var(nir_def *value, loop_info_state *state)
struct hash_entry *entry = _mesa_hash_table_search(state->loop_vars, value);
if (entry)
return entry->data;
nir_loop_variable *var = rzalloc(state, nir_loop_variable);
var->def = value;
_mesa_hash_table_insert(state->loop_vars, value, var);
return var;
else
return NULL;
}
/** Calculate an estimated cost in number of instructions
@ -116,9 +105,9 @@ instr_cost(loop_info_state *state, nir_instr *instr,
* the loop is unrolled so here we assign it a cost of 0.
*/
if ((nir_src_is_const(sel_alu->src[0].src) &&
get_loop_var(rhs.def, state)->type == basic_induction) ||
get_loop_var(rhs.def, state)) ||
(nir_src_is_const(sel_alu->src[1].src) &&
get_loop_var(lhs.def, state)->type == basic_induction)) {
get_loop_var(lhs.def, state))) {
/* Also if the selects condition is only used by the select then
* remove that alu instructons cost from the cost total also.
*/
@ -185,18 +174,6 @@ instr_cost(loop_info_state *state, nir_instr *instr,
}
}
static inline bool
is_var_alu(nir_loop_variable *var)
{
return var->def->parent_instr->type == nir_instr_type_alu;
}
static inline bool
is_var_phi(nir_loop_variable *var)
{
return var->def->parent_instr->type == nir_instr_type_phi;
}
/* If all of the instruction sources point to identical ALU instructions (as
* per nir_instrs_equal), return one of the ALU instructions. Otherwise,
* return NULL.
@ -286,38 +263,36 @@ compute_induction_information(loop_info_state *state)
info->num_induction_vars = 0;
nir_foreach_phi(phi, header) {
nir_loop_variable *var = get_loop_var(&phi->def, state);
nir_loop_variable *alu_src_var = NULL;
nir_loop_variable var = { .basis = &phi->def };
nir_foreach_phi_src(src, phi) {
if (src->pred == preheader) {
var->init_src = &src->src;
nir_foreach_phi_src(phi_src, phi) {
nir_def *src = phi_src->src.ssa;
if (phi_src->pred == preheader) {
var.init_src = &phi_src->src;
continue;
}
/* If one of the sources is in an if branch or nested loop then don't
* attempt to go any further.
*/
if (src->src.ssa->parent_instr->block->cf_node.parent != &state->loop->cf_node)
if (src->parent_instr->block->cf_node.parent != &state->loop->cf_node)
break;
nir_loop_variable *src_var = get_loop_var(src->src.ssa, state);
/* Detect inductions variables that are incremented in both branches
* of an unnested if rather than in a loop block.
*/
if (is_var_phi(src_var)) {
nir_phi_instr *src_phi =
nir_instr_as_phi(src_var->def->parent_instr);
if (src->parent_instr->type == nir_instr_type_phi) {
nir_phi_instr *src_phi = nir_instr_as_phi(src->parent_instr);
nir_alu_instr *src_phi_alu = phi_instr_as_alu(src_phi);
if (src_phi_alu) {
src_var = get_loop_var(&src_phi_alu->def, state);
src = &src_phi_alu->def;
}
}
if (is_var_alu(src_var) && !var->update_src) {
alu_src_var = src_var;
nir_alu_instr *alu = nir_instr_as_alu(src_var->def->parent_instr);
if (src->parent_instr->type == nir_instr_type_alu && !var.update_src) {
var.def = src;
nir_alu_instr *alu = nir_instr_as_alu(src->parent_instr);
/* Check for unsupported alu operations */
if (alu->op != nir_op_iadd && alu->op != nir_op_fadd &&
@ -334,45 +309,43 @@ compute_induction_information(loop_info_state *state)
if (alu->src[1 - i].src.ssa == &phi->def &&
alu_src_has_identity_swizzle(alu, 1 - i)) {
if (is_only_uniform_src(&alu->src[i].src))
var->update_src = alu->src + i;
var.update_src = alu->src + i;
}
}
}
if (!var->update_src)
if (!var.update_src)
break;
} else {
var->update_src = NULL;
var.update_src = NULL;
break;
}
}
if (var->update_src && var->init_src &&
is_only_uniform_src(var->init_src)) {
alu_src_var->init_src = var->init_src;
alu_src_var->update_src = var->update_src;
alu_src_var->basis = var->def;
alu_src_var->type = basic_induction;
if (var.update_src && var.init_src &&
is_only_uniform_src(var.init_src)) {
/* Insert induction variables into hash table. */
nir_loop_variable *alu_src_var = ralloc(state, nir_loop_variable);
*alu_src_var = var;
_mesa_hash_table_insert(state->loop_vars, alu_src_var->def, alu_src_var);
var->basis = var->def;
var->type = basic_induction;
nir_loop_variable *phi_def_var = ralloc(state, nir_loop_variable);
*phi_def_var = var;
phi_def_var->def = var.basis;
_mesa_hash_table_insert(state->loop_vars, phi_def_var->def, phi_def_var);
/* record induction variables into nir_loop_info */
nir_loop_induction_variable *ivar;
ivar = &info->induction_vars[info->num_induction_vars++];
ivar->def = var->def;
ivar->init_src = var->init_src;
ivar->update_src = var->update_src;
ivar->def = var.basis;
ivar->init_src = var.init_src;
ivar->update_src = var.update_src;
ivar = &info->induction_vars[info->num_induction_vars++];
ivar->def = alu_src_var->def;
ivar->init_src = alu_src_var->init_src;
ivar->update_src = alu_src_var->update_src;
ivar->def = var.def;
ivar->init_src = var.init_src;
ivar->update_src = var.update_src;
/* don't overflow */
assert(info->num_induction_vars <= num_phis * 2);
} else {
var->init_src = NULL;
var->update_src = NULL;
var->basis = NULL;
}
}
@ -455,7 +428,7 @@ find_array_access_via_induction(loop_info_state *state,
nir_loop_variable *array_index = get_loop_var(d->arr.index.ssa, state);
if (array_index->type != basic_induction)
if (!array_index)
continue;
if (array_index_out)
@ -1059,12 +1032,12 @@ get_induction_and_limit_vars(nir_scalar cond,
nir_loop_variable *src0_lv = get_loop_var(lhs.def, state);
nir_loop_variable *src1_lv = get_loop_var(rhs.def, state);
if (src0_lv->type == basic_induction) {
if (src0_lv) {
*ind = lhs;
*limit = rhs;
*limit_rhs = true;
return true;
} else if (src1_lv->type == basic_induction) {
} else if (src1_lv) {
*ind = rhs;
*limit = lhs;
*limit_rhs = false;