nir/validate: Assume SSA

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24432>
This commit is contained in:
Alyssa Rosenzweig 2023-08-01 11:02:13 -04:00 committed by Marge Bot
parent 71699e59a3
commit 20f38b4b41

View file

@ -40,25 +40,9 @@
*/
#ifndef NDEBUG
/*
* Per-register validation state.
*/
typedef struct {
/*
* equivalent to the uses and defs in nir_register, but built up by the
* validator. At the end, we verify that the sets have the same entries.
*/
struct set *uses, *defs;
nir_function_impl *where_defined; /* NULL for global registers */
} reg_validate_state;
typedef struct {
void *mem_ctx;
/* map of register -> validation state (struct above) */
struct hash_table *regs;
/* the current shader being validated */
nir_shader *shader;
@ -95,9 +79,6 @@ typedef struct {
/* bitset of ssa definitions we have found; used to check uniqueness */
BITSET_WORD *ssa_defs_found;
/* bitset of registers we have currently found; used to check uniqueness */
BITSET_WORD *regs_found;
/* map of variable -> function implementation where it is defined or NULL
* if it is a global variable
*/
@ -147,35 +128,6 @@ validate_num_components(validate_state *state, unsigned num_components)
validate_assert(state, nir_num_components_valid(num_components));
}
static void
validate_reg_src(nir_src *src, validate_state *state,
unsigned bit_sizes, unsigned num_components)
{
validate_assert(state, src->reg.reg != NULL);
struct hash_entry *entry;
entry = _mesa_hash_table_search(state->regs, src->reg.reg);
validate_assert(state, entry);
reg_validate_state *reg_state = (reg_validate_state *) entry->data;
if (state->instr) {
_mesa_set_add(reg_state->uses, src);
} else {
validate_assert(state, state->if_stmt);
validate_assert(state, src->is_if);
_mesa_set_add(reg_state->uses, src);
}
validate_assert(state, reg_state->where_defined == state->impl &&
"using a register declared in a different function");
if (bit_sizes)
validate_assert(state, src->reg.reg->bit_size & bit_sizes);
if (num_components)
validate_assert(state, src->reg.reg->num_components == num_components);
}
static void
validate_ssa_src(nir_src *src, validate_state *state,
unsigned bit_sizes, unsigned num_components)
@ -209,10 +161,8 @@ validate_src(nir_src *src, validate_state *state,
else
validate_assert(state, src->parent_if == state->if_stmt);
if (src->is_ssa)
validate_ssa_src(src, state, bit_sizes, num_components);
else
validate_reg_src(src, state, bit_sizes, num_components);
validate_assert(state, src->is_ssa);
validate_ssa_src(src, state, bit_sizes, num_components);
}
static void
@ -234,31 +184,6 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state)
validate_src(&src->src, state, 0, 0);
}
static void
validate_reg_dest(nir_register_dest *dest, validate_state *state,
unsigned bit_sizes, unsigned num_components)
{
validate_assert(state, dest->reg != NULL);
validate_assert(state, dest->parent_instr == state->instr);
struct hash_entry *entry2;
entry2 = _mesa_hash_table_search(state->regs, dest->reg);
validate_assert(state, entry2);
reg_validate_state *reg_state = (reg_validate_state *) entry2->data;
_mesa_set_add(reg_state->defs, dest);
validate_assert(state, reg_state->where_defined == state->impl &&
"writing to a register declared in a different function");
if (bit_sizes)
validate_assert(state, dest->reg->bit_size & bit_sizes);
if (num_components)
validate_assert(state, dest->reg->num_components == num_components);
}
static void
validate_ssa_def(nir_ssa_def *def, validate_state *state)
{
@ -285,15 +210,12 @@ static void
validate_dest(nir_dest *dest, validate_state *state,
unsigned bit_sizes, unsigned num_components)
{
if (dest->is_ssa) {
if (bit_sizes)
validate_assert(state, dest->ssa.bit_size & bit_sizes);
if (num_components)
validate_assert(state, dest->ssa.num_components == num_components);
validate_ssa_def(&dest->ssa, state);
} else {
validate_reg_dest(&dest->reg, state, bit_sizes, num_components);
}
validate_assert(state, dest->is_ssa);
if (bit_sizes)
validate_assert(state, dest->ssa.bit_size & bit_sizes);
if (num_components)
validate_assert(state, dest->ssa.num_components == num_components);
validate_ssa_def(&dest->ssa, state);
}
static void
@ -1460,49 +1382,6 @@ validate_cf_node(nir_cf_node *node, validate_state *state)
}
}
static void
prevalidate_reg_decl(nir_register *reg, validate_state *state)
{
validate_assert(state, reg->index < state->impl->reg_alloc);
validate_assert(state, !BITSET_TEST(state->regs_found, reg->index));
validate_num_components(state, reg->num_components);
BITSET_SET(state->regs_found, reg->index);
list_validate(&reg->uses);
list_validate(&reg->defs);
reg_validate_state *reg_state = ralloc(state->regs, reg_validate_state);
reg_state->uses = _mesa_pointer_set_create(reg_state);
reg_state->defs = _mesa_pointer_set_create(reg_state);
reg_state->where_defined = state->impl;
_mesa_hash_table_insert(state->regs, reg, reg_state);
}
static void
postvalidate_reg_decl(nir_register *reg, validate_state *state)
{
struct hash_entry *entry = _mesa_hash_table_search(state->regs, reg);
assume(entry);
reg_validate_state *reg_state = (reg_validate_state *) entry->data;
nir_foreach_use_including_if(src, reg) {
struct set_entry *entry = _mesa_set_search(reg_state->uses, src);
validate_assert(state, entry);
_mesa_set_remove(reg_state->uses, entry);
}
validate_assert(state, reg_state->uses->entries == 0);
nir_foreach_def(src, reg) {
struct set_entry *entry = _mesa_set_search(reg_state->defs, src);
validate_assert(state, entry);
_mesa_set_remove(reg_state->defs, entry);
}
validate_assert(state, reg_state->defs->entries == 0);
}
static void
validate_constant(nir_constant *c, const struct glsl_type *type,
validate_state *state)
@ -1680,15 +1559,6 @@ validate_function_impl(nir_function_impl *impl, validate_state *state)
validate_var_decl(var, nir_var_function_temp, state);
}
state->regs_found = reralloc(state->mem_ctx, state->regs_found,
BITSET_WORD, BITSET_WORDS(impl->reg_alloc));
memset(state->regs_found, 0, BITSET_WORDS(impl->reg_alloc) *
sizeof(BITSET_WORD));
exec_list_validate(&impl->registers);
foreach_list_typed(nir_register, reg, node, &impl->registers) {
prevalidate_reg_decl(reg, state);
}
state->ssa_defs_found = reralloc(state->mem_ctx, state->ssa_defs_found,
BITSET_WORD, BITSET_WORDS(impl->ssa_alloc));
memset(state->ssa_defs_found, 0, BITSET_WORDS(impl->ssa_alloc) *
@ -1704,10 +1574,6 @@ validate_function_impl(nir_function_impl *impl, validate_state *state)
}
validate_end_block(impl->end_block, state);
foreach_list_typed(nir_register, reg, node, &impl->registers) {
postvalidate_reg_decl(reg, state);
}
validate_assert(state, state->ssa_srcs->entries == 0);
_mesa_set_clear(state->ssa_srcs, NULL);
@ -1736,10 +1602,8 @@ static void
init_validate_state(validate_state *state)
{
state->mem_ctx = ralloc_context(NULL);
state->regs = _mesa_pointer_hash_table_create(state->mem_ctx);
state->ssa_srcs = _mesa_pointer_set_create(state->mem_ctx);
state->ssa_defs_found = NULL;
state->regs_found = NULL;
state->blocks = _mesa_pointer_set_create(state->mem_ctx);
state->var_defs = _mesa_pointer_hash_table_create(state->mem_ctx);
state->errors = _mesa_pointer_hash_table_create(state->mem_ctx);