diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index 065810bb393..fe17f39740e 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -1168,8 +1168,12 @@ static const nir_algebraic_table ${pass_name}_table = { }; bool -${pass_name}(nir_shader *shader) -{ +${pass_name}( + nir_shader *shader +% for type, name in params: + , ${type} ${name} +% endfor +) { bool progress = false; bool condition_flags[${len(condition_list)}]; const nir_shader_compiler_options *options = shader->options; @@ -1192,12 +1196,14 @@ ${pass_name}(nir_shader *shader) class AlgebraicPass(object): - def __init__(self, pass_name, transforms): + # params is a list of `("type", "name")` tuples + def __init__(self, pass_name, transforms, params=[]): self.xforms = [] self.opcode_xforms = defaultdict(lambda : []) self.pass_name = pass_name self.expression_cond = {} self.variable_cond = {} + self.params = params error = False @@ -1263,7 +1269,8 @@ class AlgebraicPass(object): expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), get_c_opcode=get_c_opcode, - itertools=itertools) + itertools=itertools, + params=self.params) # The replacement expression isn't necessarily exact if the search expression is exact. def ignore_exact(*expr):