mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-01-30 00:40:25 +01:00
nir/opt_algebraic_tests: Add an option for generating unit tests
It only emits tests for exact patterns which do not use instructions that drop precision by design. Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39076>
This commit is contained in:
parent
14fafebc1a
commit
f5864ed408
1 changed files with 338 additions and 2 deletions
|
|
@ -30,6 +30,14 @@ import re
|
|||
import traceback
|
||||
|
||||
from nir_opcodes import opcodes, type_sizes
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TestStatus(Enum):
|
||||
PASS = 0,
|
||||
XFAIL = 1,
|
||||
UNSUPPORTED = 2,
|
||||
|
||||
|
||||
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
|
||||
nir_search_max_comm_ops = 8
|
||||
|
|
@ -319,6 +327,12 @@ _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
|
|||
r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
|
||||
r"$")
|
||||
|
||||
swizzles = {'x': 0, 'y': 1, 'z': 2, 'w': 3,
|
||||
'a': 0, 'b': 1, 'c': 2, 'd': 3,
|
||||
'e': 4, 'f': 5, 'g': 6, 'h': 7,
|
||||
'i': 8, 'j': 9, 'k': 10, 'l': 11,
|
||||
'm': 12, 'n': 13, 'o': 14, 'p': 15}
|
||||
|
||||
|
||||
class Variable(Value):
|
||||
def __init__(self, val, name, varset, algebraic_pass):
|
||||
|
|
@ -338,6 +352,7 @@ class Variable(Value):
|
|||
assert self.var_name != 'False'
|
||||
|
||||
self.is_constant = m.group('const') is not None
|
||||
self.cond = m.group('cond')
|
||||
self.cond_index = get_cond_index(
|
||||
algebraic_pass.variable_cond, m.group('cond'))
|
||||
self.required_type = m.group('type')
|
||||
|
|
@ -826,6 +841,11 @@ class SearchAndReplace(object):
|
|||
else:
|
||||
self.condition = 'true'
|
||||
|
||||
if len(transform) > 3:
|
||||
self.test_status = transform[3]
|
||||
else:
|
||||
self.test_status = TestStatus.PASS
|
||||
|
||||
if self.condition not in condition_list:
|
||||
condition_list.append(self.condition)
|
||||
self.condition_index = condition_list.index(self.condition)
|
||||
|
|
@ -1145,7 +1165,7 @@ ${xform.replace.render(cache)}
|
|||
};
|
||||
|
||||
% if expression_cond:
|
||||
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
|
||||
UNUSED static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
|
||||
% for cond in expression_cond:
|
||||
${cond[0]},
|
||||
% endfor
|
||||
|
|
@ -1242,10 +1262,246 @@ ${pass_name}(
|
|||
}
|
||||
""")
|
||||
|
||||
_algebraic_pass_pattern_test_template = mako.template.Template("""
|
||||
#include <math.h>
|
||||
|
||||
#include "tests/nir_algebraic_pattern_test.h"
|
||||
#include "nir_search_helpers.h"
|
||||
|
||||
% if variable_cond:
|
||||
UNUSED static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
|
||||
% for cond in variable_cond:
|
||||
${cond[0]},
|
||||
% endfor
|
||||
};
|
||||
% endif
|
||||
|
||||
class ${pass_name}_pattern_test : public nir_algebraic_pattern_test {
|
||||
protected:
|
||||
${pass_name}_pattern_test()
|
||||
: nir_algebraic_pattern_test("${pass_name}_pattern_test")
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
% for subset, chunk in enumerate(chunks):
|
||||
#if SUBSET == ${subset}
|
||||
|
||||
% for test_name, verbose_name, xform_defs, expr_conds, search_def, replace_def, test_status, expected_result in chunk:
|
||||
TEST_F(${pass_name}_pattern_test, ${test_name})
|
||||
{
|
||||
b->shader->info.name = "${verbose_name}";
|
||||
% if expected_result == test_status.XFAIL:
|
||||
expected_result = XFAIL;
|
||||
% elif expected_result == test_status.UNSUPPORTED:
|
||||
expected_result = UNSUPPORTED;
|
||||
% endif
|
||||
% for xform_def in xform_defs:
|
||||
${xform_def}
|
||||
% endfor
|
||||
nir_unit_test_assert_eq(b, ${search_def}, ${replace_def});
|
||||
% for cond in expr_conds:
|
||||
${cond}
|
||||
% endfor
|
||||
validate_pattern();
|
||||
}
|
||||
|
||||
% endfor
|
||||
|
||||
#endif /* SUBSET == ${subset} */
|
||||
|
||||
% endfor
|
||||
""")
|
||||
|
||||
|
||||
def expression_has_float(expr):
|
||||
if isinstance(expr, (Variable, Constant)):
|
||||
return False
|
||||
|
||||
if any(expression_has_float(src) for src in expr.sources):
|
||||
return True
|
||||
|
||||
opcode = expr.opcode
|
||||
if opcode in conv_opcode_types:
|
||||
opcode += "32"
|
||||
|
||||
return "float" in opcodes[opcode].output_type or any("float" in src for src in opcodes[opcode].input_types)
|
||||
|
||||
|
||||
def expression_is_unsupported(expr):
|
||||
if isinstance(expr, Constant) or isinstance(expr, Variable):
|
||||
return False
|
||||
|
||||
if any(expression_is_unsupported(src) for src in expr.sources):
|
||||
return True
|
||||
|
||||
if expr.swizzle != -1:
|
||||
return True
|
||||
|
||||
broken_opcodes = [
|
||||
# medium precision means that the compiler can do whatever it wants which makes it unsuitable for testing.
|
||||
"f2fmp", "i2imp", "f2imp", "f2ump", "i2fmp", "u2fmp",
|
||||
# _replicated OPs do not have nir_builder functions.
|
||||
"fdot2_replicated", "fdot3_replicated", "fdot4_replicated", "fdph_replicated",
|
||||
|
||||
# The tests do not validate patterns with those opcodes correctly.
|
||||
"imad24_ir3", "imul24_relaxed", "umad24_relaxed", "umul24_relaxed",
|
||||
"udiv_aligned_4",
|
||||
]
|
||||
|
||||
if expr.opcode in broken_opcodes:
|
||||
return True
|
||||
|
||||
# These are just too slow to evaluate on our qemu cross builds. The
|
||||
# 2/3 cases should cover our patterns well enough.
|
||||
if re.match(r"(bany|ball).*equal(4|8|16)", expr.opcode):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def expression_is_inexact(expr):
|
||||
if isinstance(expr, (Variable, Constant)):
|
||||
return False
|
||||
|
||||
if any(expression_is_inexact(src) for src in expr.sources):
|
||||
return True
|
||||
|
||||
return expr.inexact or expr.contract
|
||||
|
||||
|
||||
def get_expression_name(expr):
|
||||
name = expr.opcode
|
||||
|
||||
for src in expr.sources:
|
||||
if isinstance(src, Expression):
|
||||
name += "_" + get_expression_name(src)
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def get_value_comps(expr, value_comps, num_components=1):
|
||||
if isinstance(expr, Variable):
|
||||
value_comps[expr.index] = max(value_comps.get(
|
||||
expr.index, num_components), num_components)
|
||||
|
||||
if expr.swiz is not None:
|
||||
for comp in [swizzles[c] for c in expr.swiz[1:]]:
|
||||
value_comps[expr.index] = max(
|
||||
value_comps[expr.index], comp + 1)
|
||||
return
|
||||
|
||||
if isinstance(expr, Constant):
|
||||
value_comps[expr] = max(value_comps.get(
|
||||
expr, num_components), num_components)
|
||||
return
|
||||
|
||||
opcode = expr.opcode
|
||||
if opcode in conv_opcode_types:
|
||||
opcode += "32"
|
||||
|
||||
for src_num_components, src in zip(opcodes[opcode].input_sizes, expr.sources):
|
||||
src_num_components = max(src_num_components, 1)
|
||||
get_value_comps(src, value_comps, src_num_components)
|
||||
|
||||
|
||||
def get_expression_def(expr, name, value_comps, variable_map, defs, expr_conds, fp_math_ctrl, pass_name):
|
||||
bit_size = 32 if expr.c_bit_size <= 0 else expr.c_bit_size
|
||||
|
||||
if isinstance(expr, Variable):
|
||||
if expr.index not in variable_map:
|
||||
def_name = f"{name}{len(defs)}"
|
||||
num_components = value_comps[expr.index]
|
||||
|
||||
if expr.required_type == "bool":
|
||||
defs.append(
|
||||
f"nir_def *{def_name} = nir_b2b{bit_size}(b, nir_unit_test_uniform_input(b, {num_components}, 1, {expr.index}));")
|
||||
else:
|
||||
defs.append(
|
||||
f"nir_def *{def_name} = nir_unit_test_uniform_input(b, {num_components}, {bit_size}, {expr.index});")
|
||||
|
||||
variable_map[expr.index] = def_name
|
||||
else:
|
||||
def_name = variable_map[expr.index]
|
||||
|
||||
if expr.swiz is not None:
|
||||
swizzle_name = f"{name}{len(defs)}_swizzle"
|
||||
defs.append(
|
||||
f"uint32_t {swizzle_name}[{len(expr.swiz) - 1}] = {expr.swizzle()};")
|
||||
def_name = f"nir_swizzle(b, {def_name}, {swizzle_name}, {len(expr.swiz) - 1})"
|
||||
|
||||
return def_name
|
||||
|
||||
if isinstance(expr, Constant):
|
||||
def_name = f"{name}{len(defs)}"
|
||||
|
||||
if isinstance(expr.value, bool):
|
||||
defs.append(
|
||||
f"nir_def *{def_name} = nir_imm_{"true" if expr.value else "false"}(b);")
|
||||
elif isinstance(expr.value, int):
|
||||
defs.append(
|
||||
f"nir_def *{def_name} = nir_imm_intN_t(b, {expr.value}llu, {bit_size});")
|
||||
elif isinstance(expr.value, float):
|
||||
value = str(expr.value)
|
||||
if value == "nan":
|
||||
value = "NAN"
|
||||
elif value == "-nan":
|
||||
value = "-NAN"
|
||||
elif value == "inf":
|
||||
value = "INFINITY"
|
||||
elif value == "-inf":
|
||||
value = "-INFINITY"
|
||||
defs.append(
|
||||
f"nir_def *{def_name} = nir_imm_floatN_t(b, {value}, {bit_size});")
|
||||
|
||||
if value_comps[expr] != 1:
|
||||
comps = ", ".join(def_name for c in range(0, value_comps[expr]))
|
||||
defs.append(
|
||||
f"nir_def *{def_name}_tmp[{value_comps[expr]}] = {{{comps}}}; {def_name} = nir_vec(b, {def_name}_tmp, {value_comps[expr]});")
|
||||
|
||||
return def_name
|
||||
|
||||
srcs = [get_expression_def(src, name, value_comps, variable_map, defs, expr_conds, fp_math_ctrl, pass_name)
|
||||
for src in expr.sources]
|
||||
|
||||
opcode = expr.opcode
|
||||
if opcode in conv_opcode_types:
|
||||
opcode += str(bit_size)
|
||||
|
||||
def_name = f"{name}{len(defs)}"
|
||||
|
||||
if expr.nsz:
|
||||
fp_math_ctrl.add("nir_fp_preserve_signed_zero")
|
||||
if expr.nnan:
|
||||
fp_math_ctrl.add("nir_fp_preserve_nan")
|
||||
if expr.ninf:
|
||||
fp_math_ctrl.add("nir_fp_preserve_inf")
|
||||
|
||||
defs.append(f"nir_def *{def_name} = nir_{opcode}(b, {", ".join(srcs)});")
|
||||
|
||||
if expr.cond != None and isinstance(expr, Expression):
|
||||
# These do not matter for correctness
|
||||
if expr.cond not in ["is_used_once", "is_not_used_by_if"]:
|
||||
expr_conds.append(f"if (!{expr.cond}(nir_def_as_alu({def_name})))")
|
||||
expr_conds.append(
|
||||
f" expression_cond_failed = \"{expr.cond} for {def_name}\";")
|
||||
|
||||
for src_index, expr in enumerate(expr.sources):
|
||||
# We don't include not-const checks, since we implement our test
|
||||
# evaluation by using load_consts for inputs, while the intent of
|
||||
# patterns using them is "don't do this transformation when there are
|
||||
# actual constants rather than variable values here." Similarly for
|
||||
# is_fmul in fma distribution
|
||||
if isinstance(expr, Variable) and expr.cond_index != -1 and expr.cond not in {"(is_not_const)", "(is_not_const_zero)", "(is_fmul)", "(is_not_const_and_not_fsign)"}:
|
||||
defs.append(
|
||||
f"variable_conds.push_back(nir_algebraic_pattern_test_variable_cond(nir_def_as_alu({def_name}), {src_index}, {pass_name}_variable_cond[{expr.cond_index}]));")
|
||||
|
||||
return def_name
|
||||
|
||||
|
||||
class AlgebraicPass(object):
|
||||
# params is a list of `("type", "name")` tuples
|
||||
def __init__(self, pass_name, transforms, params=[]):
|
||||
def __init__(self, pass_name, transforms, params=[], build_tests=False):
|
||||
self.xforms = []
|
||||
self.opcode_xforms = defaultdict(lambda: [])
|
||||
self.pass_name = pass_name
|
||||
|
|
@ -1255,6 +1511,11 @@ class AlgebraicPass(object):
|
|||
|
||||
error = False
|
||||
|
||||
self.tests = []
|
||||
xform_name_set = set()
|
||||
|
||||
xform_index = 0
|
||||
|
||||
for xform in transforms:
|
||||
if not isinstance(xform, SearchAndReplace):
|
||||
try:
|
||||
|
|
@ -1303,6 +1564,68 @@ class AlgebraicPass(object):
|
|||
print("{}".format(xform.search.cond), file=sys.stderr)
|
||||
error = True
|
||||
|
||||
if not build_tests:
|
||||
continue
|
||||
|
||||
if expression_is_unsupported(xform.search) or expression_is_unsupported(xform.replace):
|
||||
if xform.test_status != TestStatus.UNSUPPORTED:
|
||||
print("Transform unsupported for unit testing but not marked as TestStatus.UNSUPPORTED: "
|
||||
"{} -> {}).".format(str(xform.search),
|
||||
str(xform.replace)))
|
||||
error = True
|
||||
continue
|
||||
|
||||
name = get_expression_name(xform.search)
|
||||
if name in xform_name_set:
|
||||
name = f"{name}_{xform_index}"
|
||||
|
||||
xform_name_set.add(name)
|
||||
|
||||
value_comps = defaultdict()
|
||||
get_value_comps(xform.search, value_comps)
|
||||
get_value_comps(xform.replace, value_comps)
|
||||
|
||||
variable_map = defaultdict()
|
||||
xform_defs = []
|
||||
expr_conds = []
|
||||
fp_math_ctrl = set()
|
||||
search_def = get_expression_def(
|
||||
xform.search, "search", value_comps, variable_map, xform_defs, expr_conds, fp_math_ctrl, self.pass_name)
|
||||
replace_def = get_expression_def(
|
||||
xform.replace, "replace", value_comps, variable_map, xform_defs, expr_conds, fp_math_ctrl, self.pass_name)
|
||||
|
||||
# lowering patterns can lose precision
|
||||
if expression_is_inexact(xform.search) or (xform.condition != "true" and expression_has_float(xform.search)):
|
||||
xform_defs.append("exact = false;")
|
||||
|
||||
if fp_math_ctrl:
|
||||
xform_defs.append(
|
||||
f"fp_math_ctrl = (nir_fp_math_control) (fp_math_ctrl & ~({" | ".join(fp_math_ctrl)}));")
|
||||
|
||||
# Do a little setup before our unit_test_assert_eq so some top-level expression conditions pass and
|
||||
# we get better coverage.
|
||||
for bits in ["8", "16"]:
|
||||
if xform.search.cond == f"only_lower_{bits}_bits_used":
|
||||
xform_defs.append(
|
||||
f" nir_def *search_{bits} = nir_u2u{bits}(b, {search_def});")
|
||||
search_def = f"search_{bits}"
|
||||
xform_defs.append(
|
||||
f" nir_def *replace_{bits} = nir_u2u{bits}(b, {replace_def});")
|
||||
replace_def = f"replace_{bits}"
|
||||
if xform.search.cond == "is_only_used_as_float":
|
||||
xform_defs.append(
|
||||
f" nir_def *search_float = nir_fadd_imm(b, {search_def}, 0.0);")
|
||||
search_def = "search_float"
|
||||
xform_defs.append(
|
||||
f" nir_def *replace_float = nir_fadd_imm(b, {replace_def}, 0.0);")
|
||||
replace_def = "replace_float"
|
||||
|
||||
verbose_name = f"{str(xform.search)} -> {str(xform.replace)}"
|
||||
self.tests.append((name, verbose_name, xform_defs, expr_conds, search_def, replace_def,
|
||||
TestStatus, xform.test_status))
|
||||
|
||||
xform_index += 1
|
||||
|
||||
self.automaton = TreeAutomaton(self.xforms)
|
||||
|
||||
if error:
|
||||
|
|
@ -1322,6 +1645,19 @@ class AlgebraicPass(object):
|
|||
itertools=itertools,
|
||||
params=self.params)
|
||||
|
||||
def render_tests(self):
|
||||
chunk_len = (len(self.tests) + 7) // 8
|
||||
chunks = []
|
||||
for i in range(8):
|
||||
chunks.append(self.tests[(i * chunk_len):((i + 1) * chunk_len)])
|
||||
return _algebraic_pass_pattern_test_template.render(
|
||||
pass_name=self.pass_name,
|
||||
chunks=chunks,
|
||||
subset=i,
|
||||
variable_cond=sorted(
|
||||
self.variable_cond.items(), key=lambda kv: kv[1])
|
||||
)
|
||||
|
||||
# The replacement expression isn't necessarily exact if the search expression is exact.
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue