nir/opt_algebraic_tests: Move more of the base class code to be methods.

Less passing the *test around separately.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39369>
This commit is contained in:
Emma Anholt 2026-01-16 16:59:13 -08:00 committed by Marge Bot
parent 845e2b3954
commit 7f1a64e7f5
2 changed files with 32 additions and 27 deletions

View file

@ -17,10 +17,10 @@ nir_algebraic_pattern_test::nir_algebraic_pattern_test(const char *name)
{
}
static nir_const_value *
tmp_value(nir_algebraic_pattern_test *test, nir_def *def)
nir_const_value *
nir_algebraic_pattern_test::tmp_value(nir_def *def)
{
return &test->tmp_values[def->index * NIR_MAX_VEC_COMPONENTS];
return &tmp_values[def->index * NIR_MAX_VEC_COMPONENTS];
}
static bool
@ -36,7 +36,7 @@ def_annotate_value(nir_def *def, void *data)
FILE *output = u_memstream_get(&mem);
nir_const_value *value = tmp_value(test, def);
nir_const_value *value = test->tmp_value(def);
fprintf(output, "// ");
if (def->num_components == 1) {
@ -198,14 +198,14 @@ static const double float_inputs[INPUT_VALUE_COUNT] = {
DBL_MIN,
};
static bool
skip_test(nir_algebraic_pattern_test *test, nir_alu_instr *alu, uint32_t bit_size,
nir_const_value tmp, int32_t src_index, bool exact)
bool
nir_algebraic_pattern_test::skip_test(nir_alu_instr *alu, uint32_t bit_size,
nir_const_value tmp, int32_t src_index)
{
/* Always pass the test for signed zero/nan/inf sources if they are not preserved. */
if (bit_size >= 16) {
double val = nir_const_value_as_float(tmp, bit_size);
if ((!exact || !(test->fp_math_ctrl & nir_fp_preserve_signed_zero)) && val == 0.0 && signbit(val)) {
if ((!exact || !(fp_math_ctrl & nir_fp_preserve_signed_zero)) && val == 0.0 && signbit(val)) {
/* TODO: Could be more permissive in covering input values -- right now
* we skip if either before or after ever consume or produce a -0.0,
* but if the result was unchanged by the 0.0 signs of the srcs, or if
@ -217,9 +217,9 @@ skip_test(nir_algebraic_pattern_test *test, nir_alu_instr *alu, uint32_t bit_siz
*/
return true;
}
if ((!exact || !(test->fp_math_ctrl & nir_fp_preserve_nan)) && isnan(val))
if ((!exact || !(fp_math_ctrl & nir_fp_preserve_nan)) && isnan(val))
return true;
if ((!exact || !(test->fp_math_ctrl & nir_fp_preserve_inf)) && isinf(val))
if ((!exact || !(fp_math_ctrl & nir_fp_preserve_inf)) && isinf(val))
return true;
}
@ -263,15 +263,15 @@ compare_inexact(double a, double b, uint32_t bit_size)
* (either assert_eq was true, or we hit some UB with these inputs and the test should
* be skipped).
*/
static bool
evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
bool
nir_algebraic_pattern_test::evaluate_expression(nir_instr *instr)
{
if (instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
if (intrinsic->intrinsic == nir_intrinsic_unit_test_assert_eq) {
nir_const_value *src0 = tmp_value(test, intrinsic->src[0].ssa);
nir_const_value *src1 = tmp_value(test, intrinsic->src[1].ssa);
nir_const_value *src0 = tmp_value(intrinsic->src[0].ssa);
nir_const_value *src1 = tmp_value(intrinsic->src[1].ssa);
assert(intrinsic->src[0].ssa->bit_size == intrinsic->src[1].ssa->bit_size);
uint32_t bit_size = intrinsic->src[0].ssa->bit_size;
@ -296,7 +296,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
if (bit_size >= 16) {
double af = nir_const_value_as_float(src0[comp], bit_size);
double bf = nir_const_value_as_float(src1[comp], bit_size);
if (test->exact) {
if (exact) {
if (!(is_float && isnan(af) && isnan(bf)))
return false;
} else {
@ -324,7 +324,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
for (uint32_t i = 0; i < load_const->def.num_components; i++)
tmp_value(test, &load_const->def)[i] = load_const->value[i];
tmp_value(&load_const->def)[i] = load_const->value[i];
return false;
}
@ -342,10 +342,10 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
bit_size = alu->src[i].src.ssa->bit_size;
for (uint32_t j = 0; j < nir_ssa_alu_instr_src_components(alu, i); j++) {
nir_const_value tmp = tmp_value(test, alu->src[i].src.ssa)[alu->src[i].swizzle[j]];
nir_const_value tmp = tmp_value(alu->src[i].src.ssa)[alu->src[i].swizzle[j]];
src[i][j] = tmp;
if (skip_test(test, alu, alu->src[i].src.ssa->bit_size, tmp, i, test->exact))
if (skip_test(alu, alu->src[i].src.ssa->bit_size, tmp, i))
return true;
}
}
@ -357,10 +357,10 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
for (uint32_t i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
srcs[i] = src[i];
nir_const_value *dest = tmp_value(test, &alu->def);
nir_const_value *dest = tmp_value(&alu->def);
nir_component_mask_t poison;
nir_eval_const_opcode(alu->op, dest, &poison, alu->def.num_components, bit_size, srcs, test->b->shader->info.float_controls_execution_mode);
nir_eval_const_opcode(alu->op, dest, &poison, alu->def.num_components, bit_size, srcs, b->shader->info.float_controls_execution_mode);
/* If the inputs we chose triggered UB, then skip this particular test
* combination -- we can't assert equality of the results (and we don't have
@ -371,7 +371,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
return true;
for (uint32_t comp = 0; comp < alu->def.num_components; comp++) {
if (skip_test(test, alu, bit_size, dest[comp], -1, test->exact))
if (skip_test(alu, bit_size, dest[comp], -1))
return true;
}
@ -382,7 +382,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr)
void
nir_algebraic_pattern_test::set_inputs(uint32_t seed)
{
for (auto input: inputs) {
for (auto input : inputs) {
nir_load_const_instr *load = input.instr;
uint32_t seed_bit_size = get_seed_bit_size(input.ty);
@ -488,7 +488,7 @@ nir_algebraic_pattern_test::validate_pattern()
bool passed_or_skipped = false;
nir_foreach_instr(instr, block) {
if (evaluate_expression(this, instr)) {
if (evaluate_expression(instr)) {
passed_or_skipped = true;
if (instr->type == nir_instr_type_intrinsic) {
if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_unit_test_assert_eq)

View file

@ -6,13 +6,13 @@
#ifndef NIR_ALGEBRAIC_PATTERN_TEST_H
#define NIR_ALGEBRAIC_PATTERN_TEST_H
#include "nir.h"
#include "nir_test.h"
#include "nir_search.h"
#include "gtest/gtest-spi.h"
#include "nir.h"
#include "nir_search.h"
#include "nir_test.h"
class nir_algebraic_pattern_test_variable_cond {
public:
public:
nir_algebraic_pattern_test_variable_cond(nir_alu_instr *alu, unsigned src_index, const nir_search_variable_cond cond)
: alu(alu), src_index(src_index), cond(cond)
{
@ -56,8 +56,13 @@ class nir_algebraic_pattern_test : public nir_test {
void set_inputs(uint32_t seed);
bool check_variable_conds();
void validate_pattern();
bool evaluate_expression(nir_instr *instr);
bool skip_test(nir_alu_instr *alu, uint32_t bit_size,
nir_const_value tmp, int32_t src_index);
public:
nir_const_value *tmp_value(nir_def *def);
std::vector<nir_algebraic_pattern_test_input> inputs;
uint32_t fuzzing_bits;
bool exact = true;