nir/opt_algebraic_tests: Add an option for generating unit tests

It only emits tests for exact patterns which do not use instructions
that drop precision by design.

Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39076>
This commit is contained in:
Konstantin Seurer 2024-07-02 10:06:49 +02:00 committed by Marge Bot
parent 14fafebc1a
commit f5864ed408

View file

@ -30,6 +30,14 @@ import re
import traceback
from nir_opcodes import opcodes, type_sizes
from enum import Enum
class TestStatus(Enum):
    """Expected outcome of the generated unit test for a transform.

    A transform may carry one of these as its optional fourth tuple element;
    when it is absent the transform defaults to ``PASS``.
    """

    # NOTE: no trailing commas here -- `PASS = 0,` would silently make the
    # member value the tuple `(0,)` instead of the integer `0`.

    # The generated test is expected to succeed.
    PASS = 0
    # The generated test is expected to fail.
    XFAIL = 1
    # The transform cannot be expressed/validated by the test generator.
    UNSUPPORTED = 2
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8
@ -319,6 +327,12 @@ _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
r"$")
# Maps a single swizzle letter to its zero-based component index.
# 'x'..'w' and 'a'..'d' are aliases for components 0-3; 'e'..'p' continue
# up to component 15.  The letter set matches the `swiz` group accepted by
# _var_name_re.
swizzles = {'x': 0, 'y': 1, 'z': 2, 'w': 3,
'a': 0, 'b': 1, 'c': 2, 'd': 3,
'e': 4, 'f': 5, 'g': 6, 'h': 7,
'i': 8, 'j': 9, 'k': 10, 'l': 11,
'm': 12, 'n': 13, 'o': 14, 'p': 15}
class Variable(Value):
def __init__(self, val, name, varset, algebraic_pass):
@ -338,6 +352,7 @@ class Variable(Value):
assert self.var_name != 'False'
self.is_constant = m.group('const') is not None
self.cond = m.group('cond')
self.cond_index = get_cond_index(
algebraic_pass.variable_cond, m.group('cond'))
self.required_type = m.group('type')
@ -826,6 +841,11 @@ class SearchAndReplace(object):
else:
self.condition = 'true'
if len(transform) > 3:
self.test_status = transform[3]
else:
self.test_status = TestStatus.PASS
if self.condition not in condition_list:
condition_list.append(self.condition)
self.condition_index = condition_list.index(self.condition)
@ -1145,7 +1165,7 @@ ${xform.replace.render(cache)}
};
% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
UNUSED static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
${cond[0]},
% endfor
@ -1242,10 +1262,246 @@ ${pass_name}(
}
""")
# Mako template for the generated C++ pattern-test source of one pass.
#
# Template inputs:
#   pass_name     -- used for the fixture class name and the
#                    `<pass_name>_variable_cond` array.
#   variable_cond -- iterable whose items' first element is emitted as one
#                    array entry; the array is omitted when this is empty.
#   chunks        -- list of lists of test tuples.  Chunk number N is wrapped
#                    in `#if SUBSET == N`, so only one chunk is compiled per
#                    value of the SUBSET macro.
# Each test tuple is
#   (test_name, verbose_name, xform_defs, expr_conds, search_def,
#    replace_def, test_status, expected_result)
# where `test_status` is the enum class used to compare `expected_result`
# against its XFAIL / UNSUPPORTED members.
_algebraic_pass_pattern_test_template = mako.template.Template("""
#include <math.h>
#include "tests/nir_algebraic_pattern_test.h"
#include "nir_search_helpers.h"
% if variable_cond:
UNUSED static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
% for cond in variable_cond:
${cond[0]},
% endfor
};
% endif
class ${pass_name}_pattern_test : public nir_algebraic_pattern_test {
protected:
${pass_name}_pattern_test()
: nir_algebraic_pattern_test("${pass_name}_pattern_test")
{
}
};
% for subset, chunk in enumerate(chunks):
#if SUBSET == ${subset}
% for test_name, verbose_name, xform_defs, expr_conds, search_def, replace_def, test_status, expected_result in chunk:
TEST_F(${pass_name}_pattern_test, ${test_name})
{
b->shader->info.name = "${verbose_name}";
% if expected_result == test_status.XFAIL:
expected_result = XFAIL;
% elif expected_result == test_status.UNSUPPORTED:
expected_result = UNSUPPORTED;
% endif
% for xform_def in xform_defs:
${xform_def}
% endfor
nir_unit_test_assert_eq(b, ${search_def}, ${replace_def});
% for cond in expr_conds:
${cond}
% endfor
validate_pattern();
}
% endfor
#endif /* SUBSET == ${subset} */
% endfor
""")
def expression_has_float(expr):
    """Return True if any opcode in the expression tree touches a float type.

    Variables and constants never count.  An opcode counts when "float"
    appears in its output type or in any of its input types, as recorded in
    the `opcodes` table.
    """
    # Leaves have no opcode to inspect.
    if isinstance(expr, (Variable, Constant)):
        return False

    for operand in expr.sources:
        if expression_has_float(operand):
            return True

    # Conversion opcodes are looked up with a "32" bit-size suffix.
    op_name = expr.opcode + "32" if expr.opcode in conv_opcode_types else expr.opcode
    op_info = opcodes[op_name]

    if "float" in op_info.output_type:
        return True
    return any("float" in in_type for in_type in op_info.input_types)
def expression_is_unsupported(expr):
    """Return True if the test generator cannot handle this expression tree.

    An expression is unsupported when any of its sources is unsupported, when
    it has a swizzle (`expr.swizzle != -1`), or when its opcode is one of the
    opcodes listed below.  Variables and constants are always supported.
    """
    if isinstance(expr, (Constant, Variable)):
        return False

    for operand in expr.sources:
        if expression_is_unsupported(operand):
            return True

    if expr.swizzle != -1:
        return True

    untestable_opcodes = (
        # medium precision means that the compiler can do whatever it wants
        # which makes it unsuitable for testing.
        "f2fmp", "i2imp", "f2imp", "f2ump", "i2fmp", "u2fmp",
        # _replicated OPs do not have nir_builder functions.
        "fdot2_replicated", "fdot3_replicated", "fdot4_replicated", "fdph_replicated",
        # The tests do not validate patterns with those opcodes correctly.
        "imad24_ir3", "imul24_relaxed", "umad24_relaxed", "umul24_relaxed",
        "udiv_aligned_4",
    )
    if expr.opcode in untestable_opcodes:
        return True

    # These are just too slow to evaluate on our qemu cross builds. The
    # 2/3 cases should cover our patterns well enough.
    too_slow = re.match(r"(bany|ball).*equal(4|8|16)", expr.opcode)
    return too_slow is not None
def expression_is_inexact(expr):
    """Tell whether any expression in the tree is flagged inexact or contract.

    Variables and constants yield False.  For an expression whose sources are
    all exact, the result is `expr.inexact or expr.contract` (returned as-is,
    not coerced to bool).
    """
    if isinstance(expr, (Variable, Constant)):
        return False

    for operand in expr.sources:
        if expression_is_inexact(operand):
            return True

    return expr.inexact or expr.contract
def get_expression_name(expr):
    """Build an identifier for an expression from the opcodes in its tree.

    The result is the expression's own opcode followed by the names of its
    Expression sources (in source order), joined with underscores.  Sources
    that are not Expression instances contribute nothing.
    """
    pieces = [expr.opcode]
    pieces.extend(
        get_expression_name(operand)
        for operand in expr.sources
        if isinstance(operand, Expression)
    )
    return "_".join(pieces)
def get_value_comps(expr, value_comps, num_components=1):
    """Record how many components each variable/constant in `expr` needs.

    `value_comps` is updated in place.  Variables are keyed by `expr.index`,
    constants by the Constant object itself.  An entry only ever grows: it
    becomes the maximum of its previous value, `num_components`, and (for
    swizzled variables) one more than the highest component index selected.

    For expressions, each source is visited with the input size the `opcodes`
    table lists for that operand, clamped to at least 1.
    """
    if isinstance(expr, Variable):
        key = expr.index
        value_comps[key] = max(value_comps.get(key, num_components), num_components)
        if expr.swiz is not None:
            # expr.swiz starts with '.', so skip the first character.
            extents = [swizzles[letter] + 1 for letter in expr.swiz[1:]]
            value_comps[key] = max([value_comps[key], *extents])
        return

    if isinstance(expr, Constant):
        value_comps[expr] = max(value_comps.get(expr, num_components), num_components)
        return

    # Conversion opcodes are looked up with a "32" bit-size suffix.
    op_name = expr.opcode
    if op_name in conv_opcode_types:
        op_name = op_name + "32"

    for declared_size, operand in zip(opcodes[op_name].input_sizes, expr.sources):
        get_value_comps(operand, value_comps, max(declared_size, 1))
def get_expression_def(expr, name, value_comps, variable_map, defs, expr_conds, fp_math_ctrl, pass_name):
    """Emit C statements that build `expr` with nir_builder and return its name.

    Parameters:
        expr: a Variable, Constant or Expression to translate.
        name: prefix for generated C identifiers; each new identifier is
            `name` followed by the current length of `defs`.
        value_comps: mapping from variable index / Constant to the number of
            components to create for it.
        variable_map: mapping from variable index to the C identifier already
            emitted for that variable; updated in place so that a variable is
            only declared once.
        defs: list of C statement strings; appended to in place.
        expr_conds: list of C lines checking expression conditions; appended
            to in place.
        fp_math_ctrl: set collecting `nir_fp_preserve_*` flag names for
            expressions that have nsz / nnan / ninf set.
        pass_name: used to reference the `<pass_name>_variable_cond` array.

    Returns:
        A C expression string referring to the generated value (an identifier,
        or a `nir_swizzle(...)` call for swizzled variables).

    Note: pieces that need quotes are computed outside the f-strings, because
    reusing the enclosing quote character inside a replacement field is only
    valid syntax on Python 3.12 and newer.
    """
    bit_size = 32 if expr.c_bit_size <= 0 else expr.c_bit_size

    if isinstance(expr, Variable):
        if expr.index not in variable_map:
            def_name = f"{name}{len(defs)}"
            num_components = value_comps[expr.index]
            if expr.required_type == "bool":
                defs.append(
                    f"nir_def *{def_name} = nir_b2b{bit_size}(b, nir_unit_test_uniform_input(b, {num_components}, 1, {expr.index}));")
            else:
                defs.append(
                    f"nir_def *{def_name} = nir_unit_test_uniform_input(b, {num_components}, {bit_size}, {expr.index});")
            variable_map[expr.index] = def_name
        else:
            def_name = variable_map[expr.index]

        # variable_map stores the unswizzled input, so the swizzle is applied
        # on every use of the variable.
        if expr.swiz is not None:
            swizzle_name = f"{name}{len(defs)}_swizzle"
            defs.append(
                f"uint32_t {swizzle_name}[{len(expr.swiz) - 1}] = {expr.swizzle()};")
            def_name = f"nir_swizzle(b, {def_name}, {swizzle_name}, {len(expr.swiz) - 1})"

        return def_name

    if isinstance(expr, Constant):
        def_name = f"{name}{len(defs)}"

        # bool must be tested before int: bool is a subclass of int.
        if isinstance(expr.value, bool):
            imm_suffix = "true" if expr.value else "false"
            defs.append(
                f"nir_def *{def_name} = nir_imm_{imm_suffix}(b);")
        elif isinstance(expr.value, int):
            defs.append(
                f"nir_def *{def_name} = nir_imm_intN_t(b, {expr.value}llu, {bit_size});")
        elif isinstance(expr.value, float):
            # Map Python's spelling of special floats to the C macros.
            value = str(expr.value)
            if value == "nan":
                value = "NAN"
            elif value == "-nan":
                value = "-NAN"
            elif value == "inf":
                value = "INFINITY"
            elif value == "-inf":
                value = "-INFINITY"
            defs.append(
                f"nir_def *{def_name} = nir_imm_floatN_t(b, {value}, {bit_size});")

        # Replicate the scalar immediate when a vector is needed.
        if value_comps[expr] != 1:
            comps = ", ".join(def_name for _ in range(0, value_comps[expr]))
            defs.append(
                f"nir_def *{def_name}_tmp[{value_comps[expr]}] = {{{comps}}}; {def_name} = nir_vec(b, {def_name}_tmp, {value_comps[expr]});")

        return def_name

    # Sources are emitted first so their definitions precede this one.
    srcs = [get_expression_def(src, name, value_comps, variable_map, defs, expr_conds, fp_math_ctrl, pass_name)
            for src in expr.sources]

    opcode = expr.opcode
    if opcode in conv_opcode_types:
        opcode += str(bit_size)

    def_name = f"{name}{len(defs)}"

    if expr.nsz:
        fp_math_ctrl.add("nir_fp_preserve_signed_zero")
    if expr.nnan:
        fp_math_ctrl.add("nir_fp_preserve_nan")
    if expr.ninf:
        fp_math_ctrl.add("nir_fp_preserve_inf")

    src_args = ", ".join(srcs)
    defs.append(f"nir_def *{def_name} = nir_{opcode}(b, {src_args});")

    if expr.cond is not None and isinstance(expr, Expression):
        # These do not matter for correctness
        if expr.cond not in ["is_used_once", "is_not_used_by_if"]:
            expr_conds.append(f"if (!{expr.cond}(nir_def_as_alu({def_name})))")
            expr_conds.append(
                f" expression_cond_failed = \"{expr.cond} for {def_name}\";")

    for src_index, src in enumerate(expr.sources):
        # We don't include not-const checks, since we implement our test
        # evaluation by using load_consts for inputs, while the intent of
        # patterns using them is "don't do this transformation when there are
        # actual constants rather than variable values here." Similarly for
        # is_fmul in fma distribution
        if isinstance(src, Variable) and src.cond_index != -1 and src.cond not in {"(is_not_const)", "(is_not_const_zero)", "(is_fmul)", "(is_not_const_and_not_fsign)"}:
            defs.append(
                f"variable_conds.push_back(nir_algebraic_pattern_test_variable_cond(nir_def_as_alu({def_name}), {src_index}, {pass_name}_variable_cond[{src.cond_index}]));")

    return def_name
class AlgebraicPass(object):
# params is a list of `("type", "name")` tuples
def __init__(self, pass_name, transforms, params=[]):
def __init__(self, pass_name, transforms, params=[], build_tests=False):
self.xforms = []
self.opcode_xforms = defaultdict(lambda: [])
self.pass_name = pass_name
@ -1255,6 +1511,11 @@ class AlgebraicPass(object):
error = False
self.tests = []
xform_name_set = set()
xform_index = 0
for xform in transforms:
if not isinstance(xform, SearchAndReplace):
try:
@ -1303,6 +1564,68 @@ class AlgebraicPass(object):
print("{}".format(xform.search.cond), file=sys.stderr)
error = True
if not build_tests:
continue
if expression_is_unsupported(xform.search) or expression_is_unsupported(xform.replace):
if xform.test_status != TestStatus.UNSUPPORTED:
print("Transform unsupported for unit testing but not marked as TestStatus.UNSUPPORTED: "
"{} -> {}).".format(str(xform.search),
str(xform.replace)))
error = True
continue
name = get_expression_name(xform.search)
if name in xform_name_set:
name = f"{name}_{xform_index}"
xform_name_set.add(name)
value_comps = defaultdict()
get_value_comps(xform.search, value_comps)
get_value_comps(xform.replace, value_comps)
variable_map = defaultdict()
xform_defs = []
expr_conds = []
fp_math_ctrl = set()
search_def = get_expression_def(
xform.search, "search", value_comps, variable_map, xform_defs, expr_conds, fp_math_ctrl, self.pass_name)
replace_def = get_expression_def(
xform.replace, "replace", value_comps, variable_map, xform_defs, expr_conds, fp_math_ctrl, self.pass_name)
# lowering patterns can lose precision
if expression_is_inexact(xform.search) or (xform.condition != "true" and expression_has_float(xform.search)):
xform_defs.append("exact = false;")
if fp_math_ctrl:
xform_defs.append(
f"fp_math_ctrl = (nir_fp_math_control) (fp_math_ctrl & ~({" | ".join(fp_math_ctrl)}));")
# Do a little setup before our unit_test_assert_eq so some top-level expression conditions pass and
# we get better coverage.
for bits in ["8", "16"]:
if xform.search.cond == f"only_lower_{bits}_bits_used":
xform_defs.append(
f" nir_def *search_{bits} = nir_u2u{bits}(b, {search_def});")
search_def = f"search_{bits}"
xform_defs.append(
f" nir_def *replace_{bits} = nir_u2u{bits}(b, {replace_def});")
replace_def = f"replace_{bits}"
if xform.search.cond == "is_only_used_as_float":
xform_defs.append(
f" nir_def *search_float = nir_fadd_imm(b, {search_def}, 0.0);")
search_def = "search_float"
xform_defs.append(
f" nir_def *replace_float = nir_fadd_imm(b, {replace_def}, 0.0);")
replace_def = "replace_float"
verbose_name = f"{str(xform.search)} -> {str(xform.replace)}"
self.tests.append((name, verbose_name, xform_defs, expr_conds, search_def, replace_def,
TestStatus, xform.test_status))
xform_index += 1
self.automaton = TreeAutomaton(self.xforms)
if error:
@ -1322,6 +1645,19 @@ class AlgebraicPass(object):
itertools=itertools,
params=self.params)
def render_tests(self):
    """Render the generated pattern-test source for this pass.

    `self.tests` is split into 8 consecutive chunks of equal length (the
    trailing chunks may be shorter or empty).  The template wraps chunk N in
    `#if SUBSET == N`.

    Returns the rendered template output.
    """
    num_chunks = 8
    # Ceiling division so every test lands in one of the chunks.
    chunk_len = (len(self.tests) + num_chunks - 1) // num_chunks
    chunks = [
        self.tests[(index * chunk_len):((index + 1) * chunk_len)]
        for index in range(num_chunks)
    ]

    # No `subset` argument is passed: the template binds `subset` itself in
    # its `for subset, chunk in enumerate(chunks)` loop.
    return _algebraic_pass_pattern_test_template.render(
        pass_name=self.pass_name,
        chunks=chunks,
        variable_cond=sorted(
            self.variable_cond.items(), key=lambda kv: kv[1])
    )
# The replacement expression isn't necessarily exact if the search expression is exact.