pan/bi: Extend the bi_builder to support type variants correctly

Some opcodes come with both type and size variants. Right now, only the
size is taken into account. Extend the builder to provide wrappers that
take a nir_type in addition to the bitsize.

While at it, fix wrappers taking a compare operator to use the proper
.{i,s,u} variant based on the comparison (equal and non-equal should
use .i, other comparisons should use .{u,s}).

Signed-off-by: Boris Brezillon <boris.brezillon@collabora.com>
Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/9520>
This commit is contained in:
Boris Brezillon 2021-03-02 13:08:05 +01:00 committed by Marge Bot
parent 0113a0a1ee
commit 3c7634f7d2
2 changed files with 74 additions and 14 deletions

View file

@ -28,6 +28,31 @@ TEMPLATE = """
#include "compiler.h"
<%
def nirtypes(opcode):
split = opcode.split('.', 1)
if len(split) < 2:
split = opcode.split('_')
if len(split) <= 1:
return None
assert len(split) > 1
type = split[1]
if type[0] == 'v':
type = type[2:]
if type[0] == 'f':
return ['nir_type_float']
elif type[0] == 's':
return ['nir_type_int']
elif type[0] == 'u':
return ['nir_type_uint']
elif type[0] == 'i':
return ['nir_type_uint', 'nir_type_int']
else:
return None
def typesize(opcode):
if opcode[-3:] == '128':
return 128
@ -41,6 +66,29 @@ def typesize(opcode):
except:
return None
def condition(opcode, typecheck, sizecheck):
cond = ''
if typecheck == True:
cond += '('
types = nirtypes(opcode)
assert types != None
for T in types:
cond += "{}type == {}".format(' || ' if cond[-1] != '(' else '', T)
cond += ')'
if sizecheck == True:
cond += "{}bitsize == {}".format(' && ' if cond != '' else '', typesize(opcode))
cmpf_mods = ops[opcode]["modifiers"]["cmpf"] if "cmpf" in ops[opcode]["modifiers"] else None
if "cmpf" in ops[opcode]["modifiers"]:
cond += "{}(".format(' && ' if cond != '' else '')
for cmpf in ops[opcode]["modifiers"]["cmpf"]:
if cmpf != 'reserved':
cond += "{}cmpf == BI_CMPF_{}".format(' || ' if cond[-1] != '(' else '', cmpf.upper())
cond += ')'
return 'true' if cond == '' else cond
def to_suffix(op):
return "_to" if op["dests"] > 0 else ""
@ -81,23 +129,34 @@ bi_index bi_${opcode.replace('.', '_').lower()}(${signature(ops[opcode], modifie
<%
common_op = opcode.split('.')[0]
variants = [a for a in ops.keys() if a.split('.')[0] == common_op]
signatures = [signature(ops[op], modifiers, sized=True, no_dests=True) for op in variants]
signatures = [signature(ops[op], modifiers, no_dests=True) for op in variants]
homogenous = all([sig == signatures[0] for sig in signatures])
types = [nirtypes(x) for x in variants]
typeful = False
for t in types:
if t != types[0]:
typeful = True
sizes = [typesize(x) for x in variants]
sized = False
for size in sizes:
if size != sizes[0]:
sized = True
last = opcode == variants[-1]
%>
% if homogenous and len(variants) > 1 and last:
% for (suffix, temp, dests, ret) in (('_to', False, 1, 'instr *'), ('', True, 0, 'index')):
% if not temp or ops[opcode]["dests"] > 0:
static inline
bi_${ret} bi_${common_op.replace('.', '_').lower()}${suffix if ops[opcode]['dests'] > 0 else ''}(${signature(ops[opcode], modifiers, sized=True, no_dests=not dests)})
bi_${ret} bi_${common_op.replace('.', '_').lower()}${suffix if ops[opcode]['dests'] > 0 else ''}(${signature(ops[opcode], modifiers, typeful=typeful, sized=sized, no_dests=not dests)})
{
% for i, (variant, size) in enumerate(zip(variants, sizes)):
${"else " if i > 0 else ""} if (bitsize == ${size})
% for i, variant in enumerate(variants):
${"{}if ({})".format("else " if i > 0 else "", condition(variant, typeful, sized))}
return (bi_${variant.replace('.', '_').lower()}${to_suffix(ops[opcode])}(${arguments(ops[opcode], temp_dest = temp)}))${"->dest[0]" if temp else ""};
% endfor
else
unreachable("Invalid bitsize for ${common_op}");
unreachable("Invalid parameters for ${common_op}");
}
%endif
@ -122,10 +181,11 @@ def should_skip(mod):
def modifier_signature(op):
return sorted([m for m in op["modifiers"].keys() if not should_skip(m)])
def signature(op, modifiers, sized = False, no_dests = False):
def signature(op, modifiers, typeful = False, sized = False, no_dests = False):
return ", ".join(
["bi_builder *b"] +
(["unsigned bitsize"] if sized else []) +
(["nir_alu_type type"] if typeful == True else []) +
(["unsigned bitsize"] if sized == True else []) +
["bi_index dest{}".format(i) for i in range(0 if no_dests else op["dests"])] +
["bi_index src{}".format(i) for i in range(src_count(op))] +
["{} {}".format(

View file

@ -1670,7 +1670,7 @@ bi_emit_alu(bi_builder *b, nir_alu_instr *instr)
if (sz == 8)
bi_mux_v4i8_to(b, dst, s2, s1, s0, BI_MUX_INT_ZERO);
else
bi_csel_to(b, sz, dst, s0, bi_zero(), s1, s2, BI_CMPF_NE);
bi_csel_to(b, nir_type_float, sz, dst, s0, bi_zero(), s1, s2, BI_CMPF_NE);
break;
case nir_op_ishl:
@ -1890,27 +1890,27 @@ bi_emit_alu(bi_builder *b, nir_alu_instr *instr)
break;
case nir_op_iadd:
bi_iadd_to(b, sz, dst, s0, s1, false);
bi_iadd_to(b, nir_type_int, sz, dst, s0, s1, false);
break;
case nir_op_iadd_sat:
bi_iadd_to(b, sz, dst, s0, s1, true);
bi_iadd_to(b, nir_type_int, sz, dst, s0, s1, true);
break;
case nir_op_ihadd:
bi_hadd_to(b, sz, dst, s0, s1, BI_ROUND_RTN);
bi_hadd_to(b, nir_type_int, sz, dst, s0, s1, BI_ROUND_RTN);
break;
case nir_op_irhadd:
bi_hadd_to(b, sz, dst, s0, s1, BI_ROUND_RTP);
bi_hadd_to(b, nir_type_int, sz, dst, s0, s1, BI_ROUND_RTP);
break;
case nir_op_isub:
bi_isub_to(b, sz, dst, s0, s1, false);
bi_isub_to(b, nir_type_int, sz, dst, s0, s1, false);
break;
case nir_op_isub_sat:
bi_isub_to(b, sz, dst, s0, s1, true);
bi_isub_to(b, nir_type_int, sz, dst, s0, s1, true);
break;
case nir_op_imul: