class TestStatus(Enum):
    """Expected outcome of a generated algebraic-pattern unit test.

    A transform can carry one of these members as its optional fourth
    element; transforms that do not provide one are treated as ``PASS``.
    """

    # Values are plain ints.  Writing ``PASS = 0,`` (with a trailing comma)
    # would silently make each value a one-element tuple such as ``(0,)``,
    # which breaks value lookups like ``TestStatus(0)``.
    PASS = 0
    XFAIL = 1
    UNSUPPORTED = 2
def expression_has_float(expr):
    """Return True if any opcode in the *expr* tree has a float input or output type.

    Variables and constants on their own never count as float.
    """
    if isinstance(expr, (Variable, Constant)):
        return False

    for child in expr.sources:
        if expression_has_float(child):
            return True

    # Opcodes listed in conv_opcode_types get a "32" suffix before the
    # opcode-table lookup.
    op_name = (expr.opcode + "32") if expr.opcode in conv_opcode_types else expr.opcode
    info = opcodes[op_name]
    if "float" in info.output_type:
        return True
    return any("float" in in_type for in_type in info.input_types)


def expression_is_unsupported(expr):
    """Return True if the *expr* tree contains something the generated tests cannot handle."""
    if isinstance(expr, (Constant, Variable)):
        return False

    for child in expr.sources:
        if expression_is_unsupported(child):
            return True

    if expr.swizzle != -1:
        return True

    excluded_opcodes = (
        # Medium precision lets the compiler do whatever it wants, which
        # makes these unsuitable for testing.
        "f2fmp", "i2imp", "f2imp", "f2ump", "i2fmp", "u2fmp",
        # The _replicated ops have no nir_builder functions.
        "fdot2_replicated", "fdot3_replicated", "fdot4_replicated", "fdph_replicated",
        # The tests do not validate patterns using these opcodes correctly.
        "imad24_ir3", "imul24_relaxed", "umad24_relaxed", "umul24_relaxed",
        "udiv_aligned_4",
    )
    if expr.opcode in excluded_opcodes:
        return True

    # The wide bany/ball comparisons are too slow to evaluate on the qemu
    # cross builds; the 2/3-component cases should cover the patterns well
    # enough.
    return re.match(r"(bany|ball).*equal(4|8|16)", expr.opcode) is not None


def expression_is_inexact(expr):
    """Return a truthy value if any expression in the *expr* tree is inexact or contract."""
    if isinstance(expr, (Variable, Constant)):
        return False

    for child in expr.sources:
        if expression_is_inexact(child):
            return True

    return expr.inexact or expr.contract


def get_expression_name(expr):
    """Build a name from the opcodes of *expr* and its nested expressions, joined by "_"."""
    parts = [expr.opcode]
    # Only nested expressions contribute; variables and constants are skipped.
    parts.extend(get_expression_name(child)
                 for child in expr.sources
                 if isinstance(child, Expression))
    return "_".join(parts)


def get_value_comps(expr, value_comps, num_components=1):
    """Record in *value_comps* how many components each variable/constant needs.

    Variables are keyed by their index, constants by the constant object
    itself.  An existing entry is only ever raised, never lowered.
    """
    if isinstance(expr, Variable):
        slot = expr.index
        value_comps[slot] = max(value_comps.get(slot, num_components), num_components)

        if expr.swiz is not None:
            # expr.swiz starts with a separator character; each following
            # letter selects a component, so the value must be at least that
            # wide.
            needed = [swizzles[letter] + 1 for letter in expr.swiz[1:]]
            value_comps[slot] = max([value_comps[slot]] + needed)
        return

    if isinstance(expr, Constant):
        value_comps[expr] = max(value_comps.get(expr, num_components), num_components)
        return

    op_name = expr.opcode
    if op_name in conv_opcode_types:
        op_name = op_name + "32"

    for size, child in zip(opcodes[op_name].input_sizes, expr.sources):
        # An input size below 1 is treated as a single component.
        get_value_comps(child, value_comps, max(size, 1))
def get_expression_def(expr, name, value_comps, variable_map, defs, expr_conds, fp_math_ctrl, pass_name):
    """Emit the C statements that build *expr* and return the C text naming it.

    Parameters:
        expr: a Variable, Constant or Expression node.
        name: prefix for generated C identifiers; the current ``len(defs)``
            is appended to it for each new definition.
        value_comps: mapping filled by ``get_value_comps`` (variable index or
            Constant object -> component count).
        variable_map: variable index -> C name already emitted for it, so
            each variable is declared only once.  Updated in place.
        defs: list of C statement strings.  Appended to in place.
        expr_conds: list of C lines checking expression conditions.
            Appended to in place.
        fp_math_ctrl: set of flag names collected from the expressions'
            ``nsz``/``nnan``/``ninf`` attributes.  Updated in place.
        pass_name: prefix of the ``<pass_name>_variable_cond`` C array.

    Returns:
        The C identifier of the emitted definition, or a ``nir_swizzle(...)``
        call expression for a swizzled variable.
    """
    bit_size = 32 if expr.c_bit_size <= 0 else expr.c_bit_size

    if isinstance(expr, Variable):
        if expr.index not in variable_map:
            def_name = f"{name}{len(defs)}"
            num_components = value_comps[expr.index]

            if expr.required_type == "bool":
                defs.append(
                    f"nir_def *{def_name} = nir_b2b{bit_size}(b, nir_unit_test_uniform_input(b, {num_components}, 1, {expr.index}));")
            else:
                defs.append(
                    f"nir_def *{def_name} = nir_unit_test_uniform_input(b, {num_components}, {bit_size}, {expr.index});")

            variable_map[expr.index] = def_name
        else:
            def_name = variable_map[expr.index]

        if expr.swiz is not None:
            # expr.swiz includes its leading separator character, hence "- 1".
            swizzle_len = len(expr.swiz) - 1
            swizzle_name = f"{name}{len(defs)}_swizzle"
            defs.append(
                f"uint32_t {swizzle_name}[{swizzle_len}] = {expr.swizzle()};")
            def_name = f"nir_swizzle(b, {def_name}, {swizzle_name}, {swizzle_len})"

        return def_name

    if isinstance(expr, Constant):
        def_name = f"{name}{len(defs)}"

        # bool must be tested before int: bool is a subclass of int.
        if isinstance(expr.value, bool):
            # Computed outside the f-string: reusing the enclosing quote
            # character inside a replacement field is a SyntaxError before
            # Python 3.12.
            imm_suffix = "true" if expr.value else "false"
            defs.append(f"nir_def *{def_name} = nir_imm_{imm_suffix}(b);")
        elif isinstance(expr.value, int):
            defs.append(
                f"nir_def *{def_name} = nir_imm_intN_t(b, {expr.value}llu, {bit_size});")
        elif isinstance(expr.value, float):
            # Python's spelling of non-finite floats is not valid C; map it
            # to the C macros.
            c_spellings = {
                "nan": "NAN",
                "-nan": "-NAN",
                "inf": "INFINITY",
                "-inf": "-INFINITY",
            }
            py_text = str(expr.value)
            value = c_spellings.get(py_text, py_text)
            defs.append(
                f"nir_def *{def_name} = nir_imm_floatN_t(b, {value}, {bit_size});")

        num_comps = value_comps[expr]
        if num_comps != 1:
            # Replicate the scalar immediate into every component of a vector.
            comps = ", ".join([def_name] * num_comps)
            defs.append(
                f"nir_def *{def_name}_tmp[{num_comps}] = {{{comps}}}; {def_name} = nir_vec(b, {def_name}_tmp, {num_comps});")

        return def_name

    # Sources are emitted first so their definitions precede this one.
    srcs = [get_expression_def(src, name, value_comps, variable_map, defs, expr_conds, fp_math_ctrl, pass_name)
            for src in expr.sources]

    opcode = expr.opcode
    if opcode in conv_opcode_types:
        opcode += str(bit_size)

    def_name = f"{name}{len(defs)}"

    if expr.nsz:
        fp_math_ctrl.add("nir_fp_preserve_signed_zero")
    if expr.nnan:
        fp_math_ctrl.add("nir_fp_preserve_nan")
    if expr.ninf:
        fp_math_ctrl.add("nir_fp_preserve_inf")

    # Joined outside the f-string for the same pre-3.12 quoting reason as above.
    src_list = ", ".join(srcs)
    defs.append(f"nir_def *{def_name} = nir_{opcode}(b, {src_list});")

    # These two conditions do not matter for correctness, so they are not checked.
    if expr.cond is not None and expr.cond not in ("is_used_once", "is_not_used_by_if"):
        expr_conds.append(f"if (!{expr.cond}(nir_def_as_alu({def_name})))")
        expr_conds.append(
            f" expression_cond_failed = \"{expr.cond} for {def_name}\";")

    # We don't include not-const checks, since we implement our test
    # evaluation by using load_consts for inputs, while the intent of
    # patterns using them is "don't do this transformation when there are
    # actual constants rather than variable values here." Similarly for
    # is_fmul in fma distribution.
    skipped_conds = {"(is_not_const)", "(is_not_const_zero)", "(is_fmul)", "(is_not_const_and_not_fsign)"}
    # The loop variable is deliberately not called "expr", so the parameter
    # is not rebound while iterating its sources.
    for src_index, src in enumerate(expr.sources):
        if isinstance(src, Variable) and src.cond_index != -1 and src.cond not in skipped_conds:
            defs.append(
                f"variable_conds.push_back(nir_algebraic_pattern_test_variable_cond(nir_def_as_alu({def_name}), {src_index}, {pass_name}_variable_cond[{src.cond_index}]));")

    return def_name
{}).".format(str(xform.search), + str(xform.replace))) + error = True + continue + + name = get_expression_name(xform.search) + if name in xform_name_set: + name = f"{name}_{xform_index}" + + xform_name_set.add(name) + + value_comps = defaultdict() + get_value_comps(xform.search, value_comps) + get_value_comps(xform.replace, value_comps) + + variable_map = defaultdict() + xform_defs = [] + expr_conds = [] + fp_math_ctrl = set() + search_def = get_expression_def( + xform.search, "search", value_comps, variable_map, xform_defs, expr_conds, fp_math_ctrl, self.pass_name) + replace_def = get_expression_def( + xform.replace, "replace", value_comps, variable_map, xform_defs, expr_conds, fp_math_ctrl, self.pass_name) + + # lowering patterns can lose precision + if expression_is_inexact(xform.search) or (xform.condition != "true" and expression_has_float(xform.search)): + xform_defs.append("exact = false;") + + if fp_math_ctrl: + xform_defs.append( + f"fp_math_ctrl = (nir_fp_math_control) (fp_math_ctrl & ~({" | ".join(fp_math_ctrl)}));") + + # Do a little setup before our unit_test_assert_eq so some top-level expression conditions pass and + # we get better coverage. 
def render_tests(self):
    """Render the pattern-test source for this pass and return it as text.

    ``self.tests`` is split into eight consecutive chunks; the template
    wraps chunk ``n`` in ``#if SUBSET == n``.  Trailing chunks may be empty
    when there are fewer tests than subsets.
    """
    num_subsets = 8
    # Ceiling division, so that num_subsets chunks always cover every test.
    chunk_len = (len(self.tests) + num_subsets - 1) // num_subsets
    chunks = [self.tests[i * chunk_len:(i + 1) * chunk_len]
              for i in range(num_subsets)]

    # No ``subset`` argument is passed: the template binds ``subset`` itself
    # in its ``for subset, chunk in enumerate(chunks)`` loop.
    return _algebraic_pass_pattern_test_template.render(
        pass_name=self.pass_name,
        chunks=chunks,
        variable_cond=sorted(
            self.variable_cond.items(), key=lambda kv: kv[1])
    )