pco: add support for various selection, complex, trig ops

Signed-off-by: Simon Perretta <simon.perretta@imgtec.com>
Acked-by: Erik Faye-Lund <erik.faye-lund@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36412>
This commit is contained in:
Simon Perretta 2025-01-01 22:03:55 +00:00 committed by Marge Bot
parent 97f167f227
commit 8ec174b3f9
7 changed files with 668 additions and 0 deletions

View file

@ -979,6 +979,8 @@ if (!isnormal(dst))
dst = copysignf(0.0f, src0);
""")
# PCO-specific float copysign opcode. Constant folding uses copysign() for
# 64-bit values and copysignf() for every other bit size.
binop("fcopysign_pco", tfloat, "", "bit_size == 64 ? copysign(src0, src1) : copysignf(src0, src1)")
binop_horiz("vec2", 2, tuint, 1, tuint, 1, tuint, """
dst.x = src0.x;
dst.y = src1.x;

View file

@ -2670,12 +2670,21 @@ static inline bool pco_should_skip_pass(const char *pass)
/** Integer one (constant-class hardware register 1). */
#define pco_one pco_ref_hwreg(1, PCO_REG_CLASS_CONST)
/** Integer 31 (constant-class hardware register 31). */
#define pco_31 pco_ref_hwreg(31, PCO_REG_CLASS_CONST)
/** Integer -1/true/0xffffffff (constant-class hardware register 143). */
#define pco_true pco_ref_hwreg(143, PCO_REG_CLASS_CONST)
/** Float 1 (constant-class hardware register 64). */
#define pco_fone pco_ref_hwreg(64, PCO_REG_CLASS_CONST)
/** Float -1: the float-one constant register wrapped in pco_ref_neg(). */
#define pco_fnegone pco_ref_neg(pco_ref_hwreg(64, PCO_REG_CLASS_CONST))
/** Float infinity (constant-class hardware register 142). */
#define pco_finf pco_ref_hwreg(142, PCO_REG_CLASS_CONST)
/* Printing. */
void pco_print_ref(pco_shader *shader, pco_ref ref);
void pco_print_instr(pco_shader *shader, pco_instr *instr);

View file

@ -297,6 +297,16 @@ enum_map(OM_ATOM_OP.t, F_ATOMIC_OP, [
('xor', 'xor'),
])
# Maps the fred_type op modifier values onto the F_RED_TYPE encoding field
# values of the same names.
enum_map(OM_FRED_TYPE.t, F_RED_TYPE, [
('sin', 'sin'),
('cos', 'cos'),
])
# Maps the fred_part op modifier values onto the F_RED_PART encoding field
# values of the same names.
enum_map(OM_FRED_PART.t, F_RED_PART, [
('a', 'a'),
('b', 'b'),
])
class OpRef(object):
def __init__(self, ref_type, index, mods):
self.type = ref_type
@ -872,6 +882,90 @@ encode_map(O_FRCP,
op_ref_maps=[('0', ['w0'], ['s0'])]
)
# The four single-source float ops below share one pattern: the extended
# encoding (I_SNGL_EXT) carries the source neg/abs modifiers, while the short
# encoding (I_SNGL) is only listed with the condition that neither modifier is
# set on the source. Only the 'sngl_op' value differs between them.

# frsq -> sngl_op 'rsq'.
encode_map(O_FRSQ,
encodings=[
(I_SNGL_EXT, [
('sngl_op', 'rsq'),
('s0neg', (RM_NEG, SRC(0))),
('s0abs', (RM_ABS, SRC(0)))
]),
(I_SNGL, [('sngl_op', 'rsq')], [
(RM_NEG, SRC(0), '== false'),
(RM_ABS, SRC(0), '== false')
])
],
op_ref_maps=[('0', ['w0'], ['s0'])]
)
# flog -> sngl_op 'log'.
encode_map(O_FLOG,
encodings=[
(I_SNGL_EXT, [
('sngl_op', 'log'),
('s0neg', (RM_NEG, SRC(0))),
('s0abs', (RM_ABS, SRC(0)))
]),
(I_SNGL, [('sngl_op', 'log')], [
(RM_NEG, SRC(0), '== false'),
(RM_ABS, SRC(0), '== false')
])
],
op_ref_maps=[('0', ['w0'], ['s0'])]
)
# flogcn -> sngl_op 'logcn'.
encode_map(O_FLOGCN,
encodings=[
(I_SNGL_EXT, [
('sngl_op', 'logcn'),
('s0neg', (RM_NEG, SRC(0))),
('s0abs', (RM_ABS, SRC(0)))
]),
(I_SNGL, [('sngl_op', 'logcn')], [
(RM_NEG, SRC(0), '== false'),
(RM_ABS, SRC(0), '== false')
])
],
op_ref_maps=[('0', ['w0'], ['s0'])]
)
# fexp -> sngl_op 'exp'.
encode_map(O_FEXP,
encodings=[
(I_SNGL_EXT, [
('sngl_op', 'exp'),
('s0neg', (RM_NEG, SRC(0))),
('s0abs', (RM_ABS, SRC(0)))
]),
(I_SNGL, [('sngl_op', 'exp')], [
(RM_NEG, SRC(0), '== false'),
(RM_ABS, SRC(0), '== false')
])
],
op_ref_maps=[('0', ['w0'], ['s0'])]
)
# fred encoding (I_FRED).
#  - red_part / red_type come from the fred_part / fred_type op modifiers.
#  - iter is the immediate value of the third source.
#  - pwen is set when the third destination is not null.
#  - s0neg / s0abs follow the neg / abs modifiers on the first source.
encode_map(O_FRED,
   encodings=[
      (I_FRED, [
         ('red_part', OM_FRED_PART),
         ('iter', ('pco_ref_get_imm', SRC(2))),
         ('red_type', OM_FRED_TYPE),
         ('pwen', ('!pco_ref_is_null', DEST(2))),
         ('s0neg', (RM_NEG, SRC(0))),
         # The abs bit is driven by RM_ABS, as for every other s0abs field in
         # this file; wiring it to RM_NEG would drop abs and make neg set
         # both bits.
         ('s0abs', (RM_ABS, SRC(0)))
      ])
   ],
   op_ref_maps=[('0', [['w0', '_'], ['w1', '_'], ['p0', '_']], ['s0', ['s3', '_'], 'imm'])]
)
# fsinc -> sngl_op 'sinc'. Only the short encoding (I_SNGL) is provided, with
# the condition that the source carries neither neg nor abs. The op writes two
# destinations: w0 and p0.
encode_map(O_FSINC,
encodings=[
(I_SNGL, [('sngl_op', 'sinc')], [
(RM_NEG, SRC(0), '== false'),
(RM_ABS, SRC(0), '== false')
])
],
op_ref_maps=[('0', ['w0', 'p0'], ['s0'])]
)
encode_map(O_MBYP,
encodings=[
(I_SNGL_EXT, [
@ -1458,6 +1552,108 @@ group_map(O_FRCP,
dests=[('w[0]', ('0', DEST(0)), 'w0')]
)
# The four groups below are identical apart from the op they wrap: a main-ALU
# group header with a single phase (oporg 'p0'), only the w0 write enabled,
# the op's source routed to s0 and its destination to w0.

group_map(O_FRSQ,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[('0', O_FRSQ)],
srcs=[('s[0]', ('0', SRC(0)), 's0')],
dests=[('w[0]', ('0', DEST(0)), 'w0')]
)
group_map(O_FLOG,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[('0', O_FLOG)],
srcs=[('s[0]', ('0', SRC(0)), 's0')],
dests=[('w[0]', ('0', DEST(0)), 'w0')]
)
group_map(O_FLOGCN,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[('0', O_FLOGCN)],
srcs=[('s[0]', ('0', SRC(0)), 's0')],
dests=[('w[0]', ('0', DEST(0)), 'w0')]
)
group_map(O_FEXP,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[('0', O_FEXP)],
srcs=[('s[0]', ('0', SRC(0)), 's0')],
dests=[('w[0]', ('0', DEST(0)), 'w0')]
)
# fred group: single phase (oporg 'p0'). Unlike the other single-op groups,
# the w0/w1 write enables are computed per instruction: each is set only when
# the corresponding destination is not null. The first source is routed to s0
# and the second to s3; the third source is not routed here (the encode map
# reads it as an immediate).
group_map(O_FRED,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
('olchk', OM_OLCHK),
('w1p', ('!pco_ref_is_null', DEST(1))),
('w0p', ('!pco_ref_is_null', DEST(0))),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[('0', O_FRED)],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[3]', ('0', SRC(1)), 's3')
],
dests=[
('w[0]', ('0', DEST(0)), 'w0'),
('w[1]', ('0', DEST(1)), 'w1')
]
)
# fsinc group: single phase (oporg 'p0'), source in s0, first destination in
# w0. Only DEST(0) is mapped to a write port here; the op's second
# destination (p0 in the encode map) has no entry in dests.
group_map(O_FSINC,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[('0', O_FSINC)],
srcs=[('s[0]', ('0', SRC(0)), 's0')],
dests=[('w[0]', ('0', DEST(0)), 'w0')]
)
group_map(O_MBYP,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0'),
@ -1777,6 +1973,203 @@ group_map(O_BCMP,
]
)
# bcsel group: three phases (oporg 'p0_p1_p2').
#  - phase 0 bypasses SRC(1) into ft0, phase 1 bypasses SRC(2) into ft1.
#  - phase 2 tests SRC(0) with op 'zero' / type 'u32', then MOVC picks between
#    ft1 and is4 (is[4] is wired to ft0).
# NOTE(review): this assumes MOVC yields its second source when the test
# passes and its third otherwise, i.e. SRC(0) == 0 selects SRC(2) — confirm
# against the MOVC definition.
group_map(O_BCSEL,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0_p1_p2'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[
('0', O_MBYP, ['ft0'], [SRC(1)]),
('1', O_MBYP, ['ft1'], [SRC(2)]),
('2_tst', O_TST, ['ftt', '_'], [SRC(0), '_'], [(OM_TST_OP_MAIN, 'zero'), (OM_TST_TYPE_MAIN, 'u32'), (OM_PHASE2END, True)]),
('2_mov', O_MOVC, [DEST(0), '_'], ['ftt', 'ft1', 'is4', '_', '_'])
],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[1]', ('2_tst', SRC(0)), 'is1'),
('s[3]', ('1', SRC(0)), 's3'),
],
iss=[
('is[0]', 's1'),
('is[1]', 'fte'),
('is[4]', 'ft0'),
],
dests=[
('w[0]', ('2_mov', DEST(0)), 'w0'),
]
)
# csel group: same three-phase layout as bcsel, with two differences:
#  - the test op and type are taken from the instruction's own
#    tst_op_main / tst_type_main modifiers instead of being fixed;
#  - MOVC picks between ft0 and is4 (is[4] is wired to ft1), the opposite
#    pairing to bcsel.
# NOTE(review): assumes MOVC yields its second source when the test passes,
# so a passing test selects SRC(1) — confirm against the MOVC definition.
group_map(O_CSEL,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0_p1_p2'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[
('0', O_MBYP, ['ft0'], [SRC(1)]),
('1', O_MBYP, ['ft1'], [SRC(2)]),
('2_tst', O_TST, ['ftt', '_'], [SRC(0), '_'], [(OM_TST_OP_MAIN, OM_TST_OP_MAIN), (OM_TST_TYPE_MAIN, OM_TST_TYPE_MAIN), (OM_PHASE2END, True)]),
('2_mov', O_MOVC, [DEST(0), '_'], ['ftt', 'ft0', 'is4', '_', '_'])
],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[1]', ('2_tst', SRC(0)), 'is1'),
('s[3]', ('1', SRC(0)), 's3'),
],
iss=[
('is[0]', 's1'),
('is[1]', 'fte'),
('is[4]', 'ft1'),
],
dests=[
('w[0]', ('2_mov', DEST(0)), 'w0'),
]
)
# psel_trig group: three phases (oporg 'p0_p1_p2').
#  - phase 0 is an IMADD64 with three zero operands, is0 (is[0] is wired to
#    s3) and SRC(0) as its last source; it writes ft0 and fte.
#  - phase 1 multiplies SRC(1) by SRC(2) into ft1.
#  - phase 2 tests is1 (is[1] is wired to ft0) with op 'zero' / type 'u32',
#    then MOVC picks between fte and is4 (is[4] is wired to ft1).
# NOTE(review): presumably the IMADD64 moves the predicate in SRC(0) into ft0
# and passes s3 through to fte, so the predicate chooses between SRC(1) and
# SRC(1) * SRC(2) — confirm against the IMADD64 definition.
group_map(O_PSEL_TRIG,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0_p1_p2'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[
('0', O_IMADD64, ['ft0', 'fte'], ['pco_zero', 'pco_zero', 'pco_zero', 'is0', SRC(0)]),
('1', O_FMUL, ['ft1'], [SRC(1), SRC(2)]),
('2_tst', O_TST, ['ftt', '_'], ['is1', '_'], [(OM_TST_OP_MAIN, 'zero'), (OM_TST_TYPE_MAIN, 'u32'), (OM_PHASE2END, True)]),
('2_mov', O_MOVC, [DEST(0), '_'], ['ftt', 'fte', 'is4', '_', '_'])
],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[1]', ('0', SRC(1)), 's1'),
('s[2]', ('0', SRC(2)), 's2'),
('s[3]', ('1', SRC(0)), 's3'),
('s[4]', ('1', SRC(1)), 's4'),
],
iss=[
('is[0]', 's3'),
('is[1]', 'ft0'),
('is[4]', 'ft1'),
],
dests=[
('w[0]', ('2_mov', DEST(0)), 'w0'),
]
)
# fsign group: three phases (oporg 'p0_p1_p2').
#  - phase 0 computes sat(SRC(0) * +inf) into ft0.
#  - phase 1 bypasses the float -1 constant into ft1.
#  - phase 2 tests is1 (is[1] is wired to fte, with is[0] selecting s0, the
#    op's source) with op 'gezero', then MOVC picks between ft0 and is4
#    (is[4] is wired to ft1).
# The test uses a float (f32) comparison: an integer (s32) comparison of the
# raw float bits treats -0.0 (sign bit set) as negative and returns -1.0 for
# it.
# NOTE(review): assumes MOVC yields its second source when the test passes
# and that the saturated multiply gives 0 for zero inputs and 1 for positive
# inputs — confirm against hardware documentation.
group_map(O_FSIGN,
   hdr=(I_IGRP_HDR_MAIN, [
      ('oporg', 'p0_p1_p2'),
      ('olchk', OM_OLCHK),
      ('w1p', False),
      ('w0p', True),
      ('cc', OM_EXEC_CND),
      ('end', OM_END),
      ('atom', OM_ATOM),
      ('rpt', OM_RPT)
   ]),
   enc_ops=[
      ('0', O_FMUL, ['ft0'], [SRC(0), 'pco_finf'], [(OM_SAT, True)]),
      ('1', O_MBYP, ['ft1'], ['pco_fnegone']),
      ('2_tst', O_TST, ['ftt', '_'], ['is1', '_'], [(OM_TST_OP_MAIN, 'gezero'), (OM_TST_TYPE_MAIN, 'f32'), (OM_PHASE2END, True)]),
      ('2_mov', O_MOVC, [DEST(0), '_'], ['ftt', 'ft0', 'is4', '_', '_'])
   ],
   srcs=[
      ('s[0]', ('0', SRC(0)), 's0'),
      ('s[1]', ('0', SRC(1)), 's1'),
      ('s[3]', ('1', SRC(0)), 's3'),
   ],
   iss=[
      ('is[0]', 's0'),
      ('is[1]', 'fte'),
      ('is[4]', 'ft1'),
   ],
   dests=[
      ('w[0]', ('2_mov', DEST(0)), 'w0'),
   ]
)
# isign group: three phases (oporg 'p0_p1_p2').
#  - phase 0 is a signed IMADD64 of SRC(0) * 1 + 0, with is0 wired to s2; its
#    second result goes to fte and its first is left unbound ('_').
#  - phase 1 bypasses integer one into ft1.
#  - phase 2 tests is1 (is[1] is wired to ft0) with op 'gzero' / type 's32',
#    then MOVC picks between ft1 and is4 (is[4] is wired to fte).
# NOTE(review): is[1] reads ft0 although phase 0 binds its first destination
# to '_' — presumably the hardware still produces ft0; verify. Also assumes
# fte holds the upper word of the signed 64-bit result (0 or -1) — confirm.
group_map(O_ISIGN,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0_p1_p2'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[
('0', O_IMADD64, ['_', 'fte'], [SRC(0), 'pco_one', 'pco_zero', 'is0', '_'], [(OM_S, True)]),
('1', O_MBYP, ['ft1'], ['pco_one']),
('2_tst', O_TST, ['ftt', '_'], ['is1', '_'], [(OM_TST_OP_MAIN, 'gzero'), (OM_TST_TYPE_MAIN, 's32'), (OM_PHASE2END, True)]),
('2_mov', O_MOVC, [DEST(0), '_'], ['ftt', 'ft1', 'is4', '_', '_'])
],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[1]', ('0', SRC(1)), 's1'),
('s[2]', ('0', SRC(2)), 's2'),
('s[3]', ('1', SRC(0)), 's3'),
],
iss=[
('is[0]', 's2'),
('is[1]', 'ft0'),
('is[4]', 'fte'),
],
dests=[
('w[0]', ('2_mov', DEST(0)), 'w0'),
]
)
# fceil group: three phases (oporg 'p0_p1_p2').
#  - phase 0 computes floor(SRC(0)) + 1.0 into ft0 (floor via the RM_FLR
#    source modifier).
#  - phase 1 computes floor(SRC(0)) + 0 into ft1.
#  - phase 2 compares is1 against is2 with op 'equal' / type 'f32'
#    (is[1] is wired to fte, is[2] to ft1), then MOVC picks between ft1 and
#    is4 (is[4] is wired to ft0).
# NOTE(review): assumes fte carries the unmodified source selected by is[0]
# ('s0'), so an already-integral input keeps its floor and anything else gets
# floor + 1 — confirm.
group_map(O_FCEIL,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0_p1_p2'),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[
('0', O_FADD, ['ft0'], [SRC(0, [RM_FLR]), 'pco_fone']),
('1', O_FADD, ['ft1'], [SRC(0, [RM_FLR]), 'pco_zero']),
('2_tst', O_TST, ['ftt', '_'], ['is1', 'is2'], [(OM_TST_OP_MAIN, 'equal'), (OM_TST_TYPE_MAIN, 'f32'), (OM_PHASE2END, True)]),
('2_mov', O_MOVC, [DEST(0), '_'], ['ftt', 'ft1', 'is4', '_', '_'])
],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[1]', ('0', SRC(1)), 's1'),
('s[3]', ('1', SRC(0)), 's3'),
('s[4]', ('1', SRC(1)), 's4'),
],
iss=[
('is[0]', 's0'),
('is[1]', 'fte'),
('is[2]', 'ft1'),
('is[4]', 'ft0'),
],
dests=[
('w[0]', ('2_mov', DEST(0)), 'w0'),
]
)
group_map(O_MIN,
hdr=(I_IGRP_HDR_MAIN, [
('oporg', 'p0_p1_p2'),
@ -2150,6 +2543,30 @@ group_map(O_SHIFT,
dests=[('w[0]', ('2', DEST(0)), 'ft5')]
)
# copysign group: bitwise-ALU header with two phases (opcnt p0, p1).
#  - phase 0 (MSK_BBYP0S1) takes the constants 31 and 0 plus SRC(0) and
#    writes ft0, ft1 and ft2.
#  - phase 1 (LOGICAL, logiop 'or') combines ft1, ft2, ft1_invert and SRC(1)
#    into the destination, which is read back from ft4.
# NOTE(review): presumably phase 0 builds a 31-bit magnitude mask and passes
# SRC(0) through, so the result is SRC(0)'s magnitude with SRC(1)'s sign bit
# — confirm against the bitwise ALU definitions.
group_map(O_COPYSIGN,
hdr=(I_IGRP_HDR_BITWISE, [
('opcnt', ['p0', 'p1']),
('olchk', OM_OLCHK),
('w1p', False),
('w0p', True),
('cc', OM_EXEC_CND),
('end', OM_END),
('atom', OM_ATOM),
('rpt', OM_RPT)
]),
enc_ops=[
('0', O_MSK_BBYP0S1, ['ft0', 'ft1', 'ft2'], ['pco_31', 'pco_zero', SRC(0)]),
('1', O_LOGICAL, [DEST(0)], ['ft1', 'ft2', 'ft1_invert', SRC(1)], [(OM_LOGIOP, 'or')])
],
srcs=[
('s[0]', ('0', SRC(0)), 's0'),
('s[1]', ('0', SRC(1)), 's1'),
('s[2]', ('0', SRC(2)), 's2'),
('s[3]', ('1', SRC(3)), 's3')
],
dests=[('w[0]', ('1', DEST(0)), 'ft4')]
)
group_map(O_WOP,
hdr=(I_IGRP_HDR_CONTROL, [
('olchk', False),

View file

@ -37,7 +37,14 @@ static const nir_shader_compiler_options nir_options = {
.instance_id_includes_base_index = true,
.lower_fdiv = true,
.lower_ffract = true,
.lower_fquantize2f16 = true,
.lower_flrp32 = true,
.lower_fmod = true,
.lower_fpow = true,
.lower_fsqrt = true,
.lower_ftrunc = true,
.lower_ldexp = true,
.lower_layer_fs_input_to_sysval = true,
.compact_arrays = true,
.scalarize_ddx = true,
@ -126,6 +133,9 @@ void pco_preprocess_nir(pco_ctx *ctx, nir_shader *nir)
.allow_fp16 = false,
});
NIR_PASS(_, nir, nir_lower_frexp);
NIR_PASS(_, nir, nir_lower_flrp, 32, true);
NIR_PASS(_,
nir,
nir_remove_dead_variables,

View file

@ -11,6 +11,27 @@ b = 'b'
lower_algebraic = []
lower_algebraic_late = []
def lowered_fround_even(src):
   """Build the replacement expression that lowers fround_even(src).

   Rounding is done on |src| and the sign of src is re-applied at the end
   with fcopysign_pco. A fractional part of exactly 0.5 is a tie: it is
   replaced by ffract(floor * 0.5), which is 0.0 when the floor is even and
   0.5 when it is odd, so the final "fraction < 0.5" test rounds ties down
   for an even floor and up for an odd one.

   src -- the source expression (a variable name or nested tuple).
   Returns the nested-tuple expression tree.
   """
   magnitude = ('fabs', src)
   whole = ('ffloor', magnitude)
   frac = ('ffract', magnitude)

   # On a tie, substitute the parity probe for the fraction.
   parity_frac = ('ffract', ('fmul', whole, 0.5))
   is_tie = ('feq', frac, 0.5)
   decisive_frac = ('bcsel', is_tie, parity_frac, frac)

   # Below one half keeps the floor; otherwise step up to floor + 1.
   rounds_down = ('flt', decisive_frac, 0.5)
   rounded = ('bcsel', rounds_down, whole, ('fadd', whole, 1.0))

   return ('fcopysign_pco', rounded, src)
# Register the rule: fround_even(a) is replaced by the expression built above.
lower_algebraic.append((('fround_even', a), lowered_fround_even(a)))
lower_scmp = [
# Float comparisons + bool conversions.
(('b2f32', ('flt', a, b)), ('slt', a, 'b@32'), '!options->lower_scmp'),

View file

@ -300,6 +300,16 @@ OM_CND = op_mod_enum('cnd', [
('p0_false', 'if(!p0)'),
])
# Op modifier 'fred_type': which function the fred op targets (sin or cos).
OM_FRED_TYPE = op_mod_enum('fred_type', [
'sin',
'cos',
])
# Op modifier 'fred_part': which part of the fred sequence (a or b) the
# instruction performs.
OM_FRED_PART = op_mod_enum('fred_part', [
'a',
'b',
])
# Ops.
OM_ALU = [OM_OLCHK, OM_EXEC_CND, OM_END, OM_ATOM, OM_RPT]
@ -310,6 +320,12 @@ O_FADD = hw_op('fadd', OM_ALU + [OM_SAT], 1, 2, [], [[RM_ABS, RM_NEG, RM_FLR], [
O_FMUL = hw_op('fmul', OM_ALU + [OM_SAT], 1, 2, [], [[RM_ABS, RM_NEG, RM_FLR], [RM_ABS]])
O_FMAD = hw_op('fmad', OM_ALU + [OM_SAT, OM_LP], 1, 3, [], [[RM_ABS, RM_NEG], [RM_ABS, RM_NEG], [RM_ABS, RM_NEG, RM_FLR]])
O_FRCP = hw_op('frcp', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
# Single-destination, single-source float ops; the source accepts abs/neg.
O_FRSQ = hw_op('frsq', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
O_FLOG = hw_op('flog', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
O_FLOGCN = hw_op('flogcn', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
O_FEXP = hw_op('fexp', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
# fred: adds the fred_type/fred_part op modifiers; 3 destinations, 3 sources,
# with abs/neg accepted on the first source only.
O_FRED = hw_op('fred', OM_ALU + [OM_FRED_TYPE, OM_FRED_PART], 3, 3, [], [[RM_ABS, RM_NEG]])
# fsinc: 2 destinations, 1 source, no source modifiers declared.
O_FSINC = hw_op('fsinc', OM_ALU, 2, 1)
O_MBYP = hw_op('mbyp', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
O_FDSX = hw_op('fdsx', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
O_FDSXF = hw_op('fdsxf', OM_ALU, 1, 1, [], [[RM_ABS, RM_NEG]])
@ -350,6 +366,8 @@ O_MOVI32 = hw_op('movi32', OM_ALU, 1, 1)
O_LOGICAL = hw_op('logical', OM_ALU + [OM_LOGIOP], 1, 4)
O_SHIFT = hw_op('shift', OM_ALU + [OM_SHIFTOP], 1, 3)
# copysign: 1 destination, 2 sources, no source modifiers declared.
O_COPYSIGN = hw_op('copysign', OM_ALU, 1, 2)
O_BBYP0BM = hw_direct_op('bbyp0bm', [], 2, 2)
O_BBYP0BM_IMM32 = hw_direct_op('bbyp0bm_imm32', [], 2, 2)
O_BBYP0S1 = hw_direct_op('bbyp0s1', [], 1, 1)
@ -375,6 +393,12 @@ O_BR = hw_op('br', [OM_EXEC_CND, OM_BRANCH_CND, OM_LINK], has_target_cf_node=Tru
# Combination (> 1 instructions per group).
O_SCMP = hw_op('scmp', OM_ALU + [OM_TST_OP_MAIN], 1, 2, [], [[RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])
O_BCMP = hw_op('bcmp', OM_ALU + [OM_TST_OP_MAIN, OM_TST_TYPE_MAIN], 1, 2, [], [[RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])
# bcsel/csel: 1 destination, 3 sources; the first source takes no modifiers,
# the other two accept abs/neg. csel additionally carries the test op/type
# modifiers.
O_BCSEL = hw_op('bcsel', OM_ALU, 1, 3, [], [[], [RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])
O_CSEL = hw_op('csel', OM_ALU + [OM_TST_OP_MAIN, OM_TST_TYPE_MAIN], 1, 3, [], [[], [RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])
# psel_trig: 1 destination, 3 sources, no source modifiers declared.
O_PSEL_TRIG = hw_op('psel_trig', OM_ALU, 1, 3)
# fsign/isign/fceil: 1 destination, 1 source, no source modifiers declared.
O_FSIGN = hw_op('fsign', OM_ALU, 1, 1)
O_ISIGN = hw_op('isign', OM_ALU, 1, 1)
O_FCEIL = hw_op('fceil', OM_ALU, 1, 1)
O_MIN = hw_op('min', OM_ALU + [OM_TST_TYPE_MAIN], 1, 2, [], [[RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])
O_MAX = hw_op('max', OM_ALU + [OM_TST_TYPE_MAIN], 1, 2, [], [[RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])
O_IADD32 = hw_op('iadd32', OM_ALU + [OM_S], 1, 3, [], [[RM_ABS, RM_NEG], [RM_ABS, RM_NEG]])

View file

@ -886,6 +886,18 @@ static pco_instr *pco_trans_nir_vec(trans_ctx *tctx,
static inline enum pco_tst_op_main to_tst_op_main(nir_op op)
{
switch (op) {
case nir_op_fcsel:
case nir_op_icsel_eqz:
return PCO_TST_OP_MAIN_ZERO;
case nir_op_fcsel_gt:
case nir_op_i32csel_gt:
return PCO_TST_OP_MAIN_GZERO;
case nir_op_fcsel_ge:
case nir_op_i32csel_ge:
return PCO_TST_OP_MAIN_GEZERO;
case nir_op_slt:
case nir_op_flt:
case nir_op_ilt:
@ -918,6 +930,10 @@ static inline enum pco_tst_op_main to_tst_op_main(nir_op op)
static inline enum pco_tst_type_main to_tst_type_main(nir_op op, pco_ref src)
{
switch (op) {
case nir_op_fcsel:
case nir_op_fcsel_gt:
case nir_op_fcsel_ge:
case nir_op_slt:
case nir_op_sge:
case nir_op_seq:
@ -932,6 +948,10 @@ static inline enum pco_tst_type_main to_tst_type_main(nir_op op, pco_ref src)
case nir_op_fmax:
return PCO_TST_TYPE_MAIN_F32;
case nir_op_icsel_eqz:
case nir_op_i32csel_gt:
case nir_op_i32csel_ge:
case nir_op_ilt:
case nir_op_ige:
case nir_op_ieq:
@ -997,6 +1017,39 @@ trans_cmp(trans_ctx *tctx, nir_op op, pco_ref dest, pco_ref src0, pco_ref src1)
.tst_type_main = tst_type_main);
}
/**
 * \brief Translates a NIR {i,f}csel op into PCO.
 *
 * The test op and test type are derived from the NIR op. For nir_op_fcsel
 * the two value operands are exchanged before the csel is emitted; every
 * other op passes them through in order.
 *
 * \param[in,out] tctx Translation context.
 * \param[in] op The NIR op.
 * \param[in] dest Instruction destination.
 * \param[in] src0 Source that is tested.
 * \param[in] src1 First value source.
 * \param[in] src2 Second value source.
 * \return The translated PCO instruction.
 */
static pco_instr *trans_csel(trans_ctx *tctx,
                             nir_op op,
                             pco_ref dest,
                             pco_ref src0,
                             pco_ref src1,
                             pco_ref src2)
{
   enum pco_tst_op_main test_op = to_tst_op_main(op);
   enum pco_tst_type_main test_type = to_tst_type_main(op, src0);

   /* NOTE(review): fcsel maps to the ZERO test in to_tst_op_main, presumably
    * the inverse of NIR's select-on-non-zero sense, hence the exchange —
    * confirm.
    */
   const bool exchange = (op == nir_op_fcsel);
   pco_ref first_val = exchange ? src2 : src1;
   pco_ref second_val = exchange ? src1 : src2;

   return pco_csel(&tctx->b,
                   dest,
                   src0,
                   first_val,
                   second_val,
                   .tst_op_main = test_op,
                   .tst_type_main = test_type);
}
/**
* \brief Translates a NIR bitwise logical op into PCO.
*
@ -1141,6 +1194,76 @@ static pco_instr *trans_min_max(trans_ctx *tctx,
.tst_type_main = tst_type_main);
}
/**
 * \brief Translates a NIR trigonometric op into PCO.
 *
 * Emits, in order: fred part A, fred part B (fed with part A's result),
 * fsinc on part B's result (also writing predicate P0), and a final
 * psel_trig that produces the destination from P0, the fsinc result and
 * part B's result.
 *
 * \param[in,out] tctx Translation context.
 * \param[in] op The NIR op (nir_op_fsin or nir_op_fcos).
 * \param[in] dest Instruction destination; must be one 32-bit channel.
 * \param[in] src Instruction source.
 * \return The translated PCO instruction (the final psel_trig).
 */
static pco_instr *
trans_trig(trans_ctx *tctx, nir_op op, pco_ref dest, pco_ref src)
{
   assert(pco_ref_get_chans(dest) == 1);
   assert(pco_ref_get_bits(dest) == 32);

   /* TODO: arctan, arctanc, sinc, cosc. */
   enum pco_fred_type reduce_kind;
   if (op == nir_op_fsin)
      reduce_kind = PCO_FRED_TYPE_SIN;
   else if (op == nir_op_fcos)
      reduce_kind = PCO_FRED_TYPE_COS;
   else
      UNREACHABLE("");

   /* Part A: result is written through the second destination slot. */
   pco_ref part_a = pco_ref_new_ssa32(tctx->func);
   pco_fred(&tctx->b,
            pco_ref_null(),
            part_a,
            pco_ref_null(),
            src,
            pco_ref_null(),
            pco_ref_imm8(0),
            .fred_type = reduce_kind,
            .fred_part = PCO_FRED_PART_A);

   /* Part B: consumes part A and writes through the first destination slot. */
   pco_ref part_b = pco_ref_new_ssa32(tctx->func);
   pco_fred(&tctx->b,
            part_b,
            pco_ref_null(),
            pco_ref_null(),
            src,
            part_a,
            pco_ref_imm8(0),
            .fred_type = reduce_kind,
            .fred_part = PCO_FRED_PART_B);

   pco_ref sinc_result = pco_ref_new_ssa32(tctx->func);
   if (op == nir_op_fsin || op == nir_op_fcos)
      pco_fsinc(&tctx->b, sinc_result, pco_ref_pred(PCO_PRED_P0), part_b);
   else
      UNREACHABLE("");

   return pco_psel_trig(&tctx->b,
                        dest,
                        pco_ref_pred(PCO_PRED_P0),
                        sinc_result,
                        part_b);
}
/**
* \brief Translates a NIR alu instruction into PCO.
*
@ -1172,12 +1295,19 @@ static pco_instr *trans_alu(trans_ctx *tctx, nir_alu_instr *alu)
case nir_op_ffloor:
instr = pco_fflr(&tctx->b, dest, src[0]);
break;
case nir_op_fceil:
instr = pco_fceil(&tctx->b, dest, src[0]);
break;
case nir_op_fadd:
instr = pco_fadd(&tctx->b, dest, src[0], src[1]);
break;
case nir_op_fsub:
instr = pco_fadd(&tctx->b, dest, pco_ref_neg(src[1]), src[0]);
break;
case nir_op_fmul:
instr = pco_fmul(&tctx->b, dest, src[0], src[1]);
break;
@ -1190,6 +1320,39 @@ static pco_instr *trans_alu(trans_ctx *tctx, nir_alu_instr *alu)
instr = pco_frcp(&tctx->b, dest, src[0]);
break;
case nir_op_frsq:
instr = pco_frsq(&tctx->b, dest, src[0]);
break;
case nir_op_fexp2:
instr = pco_fexp(&tctx->b, dest, src[0]);
break;
case nir_op_flog2:
instr = pco_flog(&tctx->b, dest, src[0]);
break;
case nir_op_fsign:
instr = pco_fsign(&tctx->b, dest, src[0]);
break;
case nir_op_fsat:
instr = pco_fadd(&tctx->b, dest, src[0], pco_zero, .sat = true);
break;
case nir_op_fsin:
case nir_op_fcos:
instr = trans_trig(tctx, alu->op, dest, src[0]);
break;
case nir_op_isign:
instr = pco_isign(&tctx->b, dest, src[0]);
break;
case nir_op_fcopysign_pco:
instr = pco_copysign(&tctx->b, dest, src[0], src[1]);
break;
case nir_op_iadd:
instr = pco_iadd32(&tctx->b, dest, src[0], src[1], pco_ref_null());
break;
@ -1256,6 +1419,28 @@ static pco_instr *trans_alu(trans_ctx *tctx, nir_alu_instr *alu)
instr = trans_cmp(tctx, alu->op, dest, src[0], src[1]);
break;
case nir_op_bcsel:
instr = pco_bcsel(&tctx->b, dest, src[0], src[1], src[2]);
break;
case nir_op_fcsel:
case nir_op_fcsel_gt:
case nir_op_fcsel_ge:
case nir_op_icsel_eqz:
case nir_op_i32csel_gt:
case nir_op_i32csel_ge:
instr = trans_csel(tctx, alu->op, dest, src[0], src[1], src[2]);
break;
case nir_op_b2f32:
instr = pco_bcsel(&tctx->b, dest, src[0], pco_fone, pco_zero);
break;
case nir_op_b2i32:
instr = pco_bcsel(&tctx->b, dest, src[0], pco_one, pco_zero);
break;
case nir_op_iand:
case nir_op_ior:
case nir_op_ixor: