From 7f1a64e7f50d4ca1751baf065fc3e99842b26491 Mon Sep 17 00:00:00 2001 From: Emma Anholt Date: Fri, 16 Jan 2026 16:59:13 -0800 Subject: [PATCH] nir/opt_algebraic_tests: Move more of the base class code to be methods. Less passing the *test around separately. Part-of: --- .../nir/tests/nir_algebraic_pattern_test.cpp | 46 +++++++++---------- .../nir/tests/nir_algebraic_pattern_test.h | 13 ++++-- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/compiler/nir/tests/nir_algebraic_pattern_test.cpp b/src/compiler/nir/tests/nir_algebraic_pattern_test.cpp index 3ebfba3a2b7..6d703daba11 100644 --- a/src/compiler/nir/tests/nir_algebraic_pattern_test.cpp +++ b/src/compiler/nir/tests/nir_algebraic_pattern_test.cpp @@ -17,10 +17,10 @@ nir_algebraic_pattern_test::nir_algebraic_pattern_test(const char *name) { } -static nir_const_value * -tmp_value(nir_algebraic_pattern_test *test, nir_def *def) +nir_const_value * +nir_algebraic_pattern_test::tmp_value(nir_def *def) { - return &test->tmp_values[def->index * NIR_MAX_VEC_COMPONENTS]; + return &tmp_values[def->index * NIR_MAX_VEC_COMPONENTS]; } static bool @@ -36,7 +36,7 @@ def_annotate_value(nir_def *def, void *data) FILE *output = u_memstream_get(&mem); - nir_const_value *value = tmp_value(test, def); + nir_const_value *value = test->tmp_value(def); fprintf(output, "// "); if (def->num_components == 1) { @@ -198,14 +198,14 @@ static const double float_inputs[INPUT_VALUE_COUNT] = { DBL_MIN, }; -static bool -skip_test(nir_algebraic_pattern_test *test, nir_alu_instr *alu, uint32_t bit_size, - nir_const_value tmp, int32_t src_index, bool exact) +bool +nir_algebraic_pattern_test::skip_test(nir_alu_instr *alu, uint32_t bit_size, + nir_const_value tmp, int32_t src_index) { /* Always pass the test for signed zero/nan/inf sources if they are not preserved. */ if (bit_size >= 16) { double val = nir_const_value_as_float(tmp, bit_size); - if ((!exact || !(test->fp_math_ctrl & nir_fp_preserve_signed_zero)) && val == 0.0 && signbit(val)) { + if ((!exact || !(fp_math_ctrl & nir_fp_preserve_signed_zero)) && val == 0.0 && signbit(val)) { /* TODO: Could be more permissive in covering input values -- right now * we skip if either before or after ever consume or produce a -0.0, * but if the result was unchanged by the 0.0 signs of the srcs, or if @@ -217,9 +217,9 @@ skip_test(nir_algebraic_pattern_test *test, nir_alu_instr *alu, uint32_t bit_siz */ return true; } - if ((!exact || !(test->fp_math_ctrl & nir_fp_preserve_nan)) && isnan(val)) + if ((!exact || !(fp_math_ctrl & nir_fp_preserve_nan)) && isnan(val)) return true; - if ((!exact || !(test->fp_math_ctrl & nir_fp_preserve_inf)) && isinf(val)) + if ((!exact || !(fp_math_ctrl & nir_fp_preserve_inf)) && isinf(val)) return true; } @@ -263,15 +263,15 @@ compare_inexact(double a, double b, uint32_t bit_size) * (either assert_eq was true, or we hit some UB with these inputs and the test should * be skipped). */ -static bool -evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) +bool +nir_algebraic_pattern_test::evaluate_expression(nir_instr *instr) { if (instr->type == nir_instr_type_intrinsic) { nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr); if (intrinsic->intrinsic == nir_intrinsic_unit_test_assert_eq) { - nir_const_value *src0 = tmp_value(test, intrinsic->src[0].ssa); - nir_const_value *src1 = tmp_value(test, intrinsic->src[1].ssa); + nir_const_value *src0 = tmp_value(intrinsic->src[0].ssa); + nir_const_value *src1 = tmp_value(intrinsic->src[1].ssa); assert(intrinsic->src[0].ssa->bit_size == intrinsic->src[1].ssa->bit_size); uint32_t bit_size = intrinsic->src[0].ssa->bit_size; @@ -296,7 +296,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) if (bit_size >= 16) { double af = nir_const_value_as_float(src0[comp], bit_size); double bf = nir_const_value_as_float(src1[comp], bit_size); - if (test->exact) { + if (exact) { if (!(is_float && isnan(af) && isnan(bf))) return false; } else { @@ -324,7 +324,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) nir_load_const_instr *load_const = nir_instr_as_load_const(instr); for (uint32_t i = 0; i < load_const->def.num_components; i++) - tmp_value(test, &load_const->def)[i] = load_const->value[i]; + tmp_value(&load_const->def)[i] = load_const->value[i]; return false; } @@ -342,10 +342,10 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) bit_size = alu->src[i].src.ssa->bit_size; for (uint32_t j = 0; j < nir_ssa_alu_instr_src_components(alu, i); j++) { - nir_const_value tmp = tmp_value(test, alu->src[i].src.ssa)[alu->src[i].swizzle[j]]; + nir_const_value tmp = tmp_value(alu->src[i].src.ssa)[alu->src[i].swizzle[j]]; src[i][j] = tmp; - if (skip_test(test, alu, alu->src[i].src.ssa->bit_size, tmp, i, test->exact)) + if (skip_test(alu, alu->src[i].src.ssa->bit_size, tmp, i)) return true; } } @@ -357,10 +357,10 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) for (uint32_t i = 0; i < nir_op_infos[alu->op].num_inputs; i++) srcs[i] = src[i]; - nir_const_value *dest = tmp_value(test, &alu->def); + nir_const_value *dest = tmp_value(&alu->def); nir_component_mask_t poison; - nir_eval_const_opcode(alu->op, dest, &poison, alu->def.num_components, bit_size, srcs, test->b->shader->info.float_controls_execution_mode); + nir_eval_const_opcode(alu->op, dest, &poison, alu->def.num_components, bit_size, srcs, b->shader->info.float_controls_execution_mode); /* If the inputs we chose triggered UB, then skip this particular test * combination -- we can't assert equality of the results (and we don't have @@ -371,7 +371,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) return true; for (uint32_t comp = 0; comp < alu->def.num_components; comp++) { - if (skip_test(test, alu, bit_size, dest[comp], -1, test->exact)) + if (skip_test(alu, bit_size, dest[comp], -1)) return true; } @@ -382,7 +382,7 @@ evaluate_expression(nir_algebraic_pattern_test *test, nir_instr *instr) void nir_algebraic_pattern_test::set_inputs(uint32_t seed) { - for (auto input: inputs) { + for (auto input : inputs) { nir_load_const_instr *load = input.instr; uint32_t seed_bit_size = get_seed_bit_size(input.ty); @@ -488,7 +488,7 @@ nir_algebraic_pattern_test::validate_pattern() bool passed_or_skipped = false; nir_foreach_instr(instr, block) { - if (evaluate_expression(this, instr)) { + if (evaluate_expression(instr)) { passed_or_skipped = true; if (instr->type == nir_instr_type_intrinsic) { if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_unit_test_assert_eq) diff --git a/src/compiler/nir/tests/nir_algebraic_pattern_test.h b/src/compiler/nir/tests/nir_algebraic_pattern_test.h index 8cb4e6cfb6d..d0e536d32c3 100644 --- a/src/compiler/nir/tests/nir_algebraic_pattern_test.h +++ b/src/compiler/nir/tests/nir_algebraic_pattern_test.h @@ -6,13 +6,13 @@ #ifndef NIR_ALGEBRAIC_PATTERN_TEST_H #define NIR_ALGEBRAIC_PATTERN_TEST_H -#include "nir.h" -#include "nir_test.h" -#include "nir_search.h" #include "gtest/gtest-spi.h" +#include "nir.h" +#include "nir_search.h" +#include "nir_test.h" class nir_algebraic_pattern_test_variable_cond { -public: + public: nir_algebraic_pattern_test_variable_cond(nir_alu_instr *alu, unsigned src_index, const nir_search_variable_cond cond) : alu(alu), src_index(src_index), cond(cond) { @@ -56,8 +56,13 @@ class nir_algebraic_pattern_test : public nir_test { void set_inputs(uint32_t seed); bool check_variable_conds(); void validate_pattern(); + bool evaluate_expression(nir_instr *instr); + bool skip_test(nir_alu_instr *alu, uint32_t bit_size, + nir_const_value tmp, int32_t src_index); public: + nir_const_value *tmp_value(nir_def *def); + std::vector inputs; uint32_t fuzzing_bits; bool exact = true;