mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-09 04:38:03 +02:00
nir/algebraic: Add a bit-size validator
This commit adds a validator that ensures that all expressions passed through nir_algebraic are 100% unambiguous as far as bit sizes are concerned. This way, a bit-size problem is a compile-time error rather than a hard-to-trace C exception some time later. Reviewed-by: Samuel Iglesias Gonsálvez <siglesias@igalia.com>
This commit is contained in:
parent
8a3e344180
commit
e0806930ad
1 changed files with 270 additions and 0 deletions
|
|
@ -33,6 +33,19 @@ import mako.template
|
|||
import re
|
||||
import traceback
|
||||
|
||||
from nir_opcodes import opcodes
|
||||
|
||||
# Splits a type string such as "int32" or "float" into an optional base type
# name and an optional trailing run of digits giving the bit size.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size spelled out in *type_str*, or 0 if it has none.

    The string must start with one of the base type names recognised by
    ``_type_re`` (int, uint, bool, float); otherwise the assertion fails.
    A type with no digits after the base name is "unsized" and yields 0.
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    digits = match.group('bits')
    return 0 if digits is None else int(digits)
|
||||
|
||||
# Represents a set of variables, each with a unique id
|
||||
class VarSet(object):
|
||||
def __init__(self):
|
||||
|
|
@ -188,6 +201,261 @@ class Expression(Value):
|
|||
srcs = "\n".join(src.render() for src in self.sources)
|
||||
return srcs + super(Expression, self).render()
|
||||
|
||||
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which
    it is equivalent.  Two integers are equivalent precisely when they have
    the same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer of the same class.
        # An integer that is not a key is its own canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every link points to a strictly larger integer, so the chain
        # cannot cycle and this walk always terminates.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form.

        After the call, everything that was equivalent to ``a`` and
        everything that was equivalent to ``b`` share the returned canonical
        form, which is the larger of the two previous canonical forms.
        """
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the canonical forms rather than a and b themselves.
        # Re-pointing a non-canonical member would detach it from its old
        # canonical form and leave the rest of that class behind.
        if canon_a != c:
            self._remap[canon_a] = c

        if canon_b != c:
            self._remap[canon_b] = c

        return c
|
||||
|
||||
class BitSizeValidator(object):
    """Check that a search-and-replace pair is unambiguous in its bit sizes.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  Every source and destination of an ALU operation has a
    type, and that type may or may not specify a bit size.  Sources and
    destinations whose type has no bit size are "unsized" and take on the bit
    size of the corresponding register or SSA value.  NIR has two simple
    bit-size rules, which nir_validator enforces:

     1) A given SSA def or register has a single bit size that is respected
        by everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    To keep nir_algebraic simple to use, nir_search performs a form of
    bit-size inference based on those two rules, much like type inference in
    many programming languages.  If, for instance, you are constructing an
    add operation and you know the second source is 16-bit, then the other
    source and the destination must be 16-bit as well.  There are, however,
    cases where the inference is ambiguous or contradictory.  Consider:

        (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    usub_borrow is well-defined for any integer bit size, but b2i always
    produces a 32-bit result, so this could replace a 64-bit expression with
    one that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider:

        (('bcsel', a, b, 0), ('iand', a, b))

    In the search expression a must be 32-bit but b may have any bit size.
    With a 64-bit b we would end up trying to and a 32-bit value with a
    64-bit value, which would be invalid.

    This class provides a validation layer that proves a given
    search-and-replace operation is 100% well-defined before any code is
    generated, so such bugs are caught at compile time rather than at run
    time.

    The validator works much like the bitsize_tree in nir_search, only a
    little more subtly.  Rather than tracking plain bit sizes, it tracks
    "bit classes", each represented by an integer: 0 means nothing is known
    yet, positive values are actual bit sizes, and negative values denote
    equivalence classes of sizes that must be equal but have not yet
    received an actual size.  The first stage uses the bitsize_tree algorithm
    to assign a bit class to each variable and assert-fails on any
    inconsistency.  The second stage uses that information to prove that the
    replacement expression can always be validly constructed.
    """

    def __init__(self, varset):
        # Number of generic (negative) classes handed out so far.
        self._num_classes = 0
        # One bit class per variable in the set; 0 means "not assigned yet".
        var_count = len(varset.names)
        self._var_classes = [0] * var_count
        # Records which bit classes have been found to be the same.
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that ``replace`` is well-defined wherever ``search`` matches.

        Fails with an AssertionError as soon as any inconsistency or
        ambiguity is found; returns nothing on success.
        """
        # Stage one: infer bit classes from the search expression.
        search_class = self._propagate_bit_size_up(search)
        if search_class == 0:
            search_class = self._new_class()
        self._propagate_bit_class_down(search, search_class)

        # Stage two: check the replacement against the inferred classes.
        replace_class = self._validate_bit_class_up(replace)
        assert replace_class == 0 or replace_class == search_class
        self._validate_bit_class_down(replace, search_class)

    def _new_class(self):
        """Allocate and return a fresh generic class (a new negative id)."""
        allocated = self._num_classes + 1
        self._num_classes = allocated
        return -allocated

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` belongs to ``bit_class``.

        If the variable already has a class, the two classes are merged.
        That is only allowed when the existing canonical class is still
        generic (negative) or is exactly ``bit_class``.
        """
        assert bit_class != 0
        current = self._var_classes[var_id]
        if current == 0:
            self._var_classes[var_id] = bit_class
            return

        relation = self._class_relation
        canonical = relation.get_canonical(current)
        assert canonical < 0 or canonical == bit_class
        self._var_classes[var_id] = relation.add_equiv(current, bit_class)

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class currently held by ``var_id``."""
        stored = self._var_classes[var_id]
        return self._class_relation.get_canonical(stored)

    def _propagate_bit_size_up(self, val):
        """Compute the bit size of ``val`` from its leaves (0 if unknown).

        For an expression, the size shared by its unsized sources is stored
        on ``val.common_size`` as a side effect.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        if not isinstance(val, Expression):
            return None

        op_info = opcodes[val.opcode]
        val.common_size = 0
        for idx in range(op_info.num_inputs):
            found = self._propagate_bit_size_up(val.sources[idx])
            if found == 0:
                # Nothing known about this source yet.
                continue

            declared = type_bits(op_info.input_types[idx])
            if declared == 0:
                # Unsized input: it must agree with every other unsized one.
                assert val.common_size == 0 or found == val.common_size
                val.common_size = found
            else:
                assert found == declared

        out_bits = type_bits(op_info.output_type)
        if out_bits != 0:
            assert val.bit_size == 0 or val.bit_size == out_bits
            return out_bits

        # Unsized output: it shares the common size of the unsized inputs.
        if val.common_size == 0:
            val.common_size = val.bit_size
        else:
            assert val.bit_size == 0 or val.bit_size == val.common_size
        return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push ``bit_class`` from the root of ``val`` down to its variables."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class
            return

        if isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)
            return

        if not isinstance(val, Expression):
            return

        op_info = opcodes[val.opcode]
        out_bits = type_bits(op_info.output_type)
        if out_bits == 0:
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class
        else:
            assert bit_class == 0 or bit_class == out_bits

        if val.common_size:
            shared = val.common_size
        elif op_info.num_inputs:
            # The actual size is completely unknown at this point, so the
            # unsized sources are tied together through a generic class.
            shared = self._new_class()

        for idx in range(op_info.num_inputs):
            declared = type_bits(op_info.input_types[idx])
            target = declared if declared != 0 else shared
            self._propagate_bit_class_down(val.sources[idx], target)

    def _validate_bit_class_up(self, val):
        """Compute the bit class of ``val`` from its leaves (0 if unknown).

        For an expression, the class shared by its unsized sources is stored
        on ``val.common_class`` as a side effect.
        """
        if isinstance(val, Constant):
            return val.bit_size

        if isinstance(val, Variable):
            found_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should
            # already have been given a class.
            assert found_class != 0

            # An explicit, user-provided size *must* exactly match the
            # search.  It cannot be implicitly sized, because otherwise we
            # could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == found_class

            return found_class

        if not isinstance(val, Expression):
            return None

        op_info = opcodes[val.opcode]
        val.common_class = 0
        for idx in range(op_info.num_inputs):
            found_class = self._validate_bit_class_up(val.sources[idx])
            if found_class == 0:
                continue

            declared = type_bits(op_info.input_types[idx])
            if declared == 0:
                assert val.common_class == 0 or found_class == val.common_class
                val.common_class = found_class
            else:
                assert found_class == declared

        out_bits = type_bits(op_info.output_type)
        if out_bits != 0:
            assert val.bit_size == 0 or val.bit_size == out_bits
            return out_bits

        if val.common_class == 0:
            val.common_class = val.bit_size
        else:
            assert val.bit_size == 0 or val.bit_size == val.common_class
        return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check that ``bit_class`` is consistent all the way down ``val``."""
        # At this point everything *must* have a bit class; otherwise there
        # is a value we do not know how to define.
        assert bit_class != 0

        if isinstance(val, (Constant, Variable)):
            assert val.bit_size == 0 or val.bit_size == bit_class
            return

        if not isinstance(val, Expression):
            return

        op_info = opcodes[val.opcode]
        out_bits = type_bits(op_info.output_type)
        if out_bits == 0:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class
        else:
            assert bit_class == out_bits

        for idx in range(op_info.num_inputs):
            declared = type_bits(op_info.input_types[idx])
            target = declared if declared != 0 else val.common_class
            self._validate_bit_class_down(val.sources[idx], target)
|
||||
|
||||
_optimization_ids = itertools.count()
|
||||
|
||||
condition_list = ['true']
|
||||
|
|
@ -220,6 +488,8 @@ class SearchAndReplace(object):
|
|||
else:
|
||||
self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
|
||||
|
||||
BitSizeValidator(varset).validate(self.search, self.replace)
|
||||
|
||||
_algebraic_pass_template = mako.template.Template("""
|
||||
#include "nir.h"
|
||||
#include "nir_search.h"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue