nir: add sdot_2x16 and udot_2x16 opcodes

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12617>
This commit is contained in:
Rhys Perry 2021-08-30 13:56:01 +01:00 committed by Marge Bot
parent ae00f5af61
commit 41ecef7855
3 changed files with 74 additions and 0 deletions

View file

@ -3728,6 +3728,9 @@ typedef struct nir_shader_compiler_options {
/** Backend supports sudot_4x8 opcodes. */
bool has_sudot_4x8;
/** Backend supports sdot_2x16 and udot_2x16 opcodes. */
bool has_dot_2x16;
/* Whether to generate only scoped_barrier intrinsics instead of the set of
* memory and control barrier intrinsics based on GLSL.
*/

View file

@ -1426,3 +1426,53 @@ opcode("sudot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")
# src0 and src1 are i16vec2 packed in an int32, and src2 is an int32. The int16
# components are sign-extended to 32-bits, and a dot-product is performed on
# the resulting vectors. src2 is added to the result of the dot-product.
opcode("sdot_2x16_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
False, _2src_commutative, """
const int32_t v0x = (int16_t)(src0 );
const int32_t v0y = (int16_t)(src0 >> 16);
const int32_t v1x = (int16_t)(src1 );
const int32_t v1y = (int16_t)(src1 >> 16);
dst = (v0x * v1x) + (v0y * v1y) + src2;
""")
# Like sdot_2x16_iadd, but unsigned.
opcode("udot_2x16_uadd", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32],
False, _2src_commutative, """
const uint32_t v0x = (uint16_t)(src0 );
const uint32_t v0y = (uint16_t)(src0 >> 16);
const uint32_t v1x = (uint16_t)(src1 );
const uint32_t v1y = (uint16_t)(src1 >> 16);
dst = (v0x * v1x) + (v0y * v1y) + src2;
""")
# Like sdot_2x16_iadd, but the result is clampled to the range [-0x80000000, 0x7ffffffff].
opcode("sdot_2x16_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
False, _2src_commutative, """
const int64_t v0x = (int16_t)(src0 );
const int64_t v0y = (int16_t)(src0 >> 16);
const int64_t v1x = (int16_t)(src1 );
const int64_t v1y = (int16_t)(src1 >> 16);
const int64_t tmp = (v0x * v1x) + (v0y * v1y) + src2;
dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")
# Like udot_2x16_uadd, but the result is clampled to the range [0, 0xfffffffff].
opcode("udot_2x16_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
False, _2src_commutative, """
const uint64_t v0x = (uint16_t)(src0 );
const uint64_t v0y = (uint16_t)(src0 >> 16);
const uint64_t v1x = (uint16_t)(src1 );
const uint64_t v1y = (uint16_t)(src1 >> 16);
const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + src2;
dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp;
""")

View file

@ -197,6 +197,10 @@ optimizations = [
(('udot_4x8_uadd', a, 0, b), b),
(('sdot_4x8_iadd_sat', a, 0, b), b),
(('udot_4x8_uadd_sat', a, 0, b), b),
(('sdot_2x16_iadd', a, 0, b), b),
(('udot_2x16_uadd', a, 0, b), b),
(('sdot_2x16_iadd_sat', a, 0, b), b),
(('udot_2x16_uadd_sat', a, 0, b), b),
# sudot_4x8_iadd is not commutative at all, so the patterns must be
# duplicated with zeros on each of the first positions.
@ -208,6 +212,8 @@ optimizations = [
(('iadd', ('sdot_4x8_iadd(is_used_once)', a, b, '#c'), '#d'), ('sdot_4x8_iadd', a, b, ('iadd', c, d))),
(('iadd', ('udot_4x8_uadd(is_used_once)', a, b, '#c'), '#d'), ('udot_4x8_uadd', a, b, ('iadd', c, d))),
(('iadd', ('sudot_4x8_iadd(is_used_once)', a, b, '#c'), '#d'), ('sudot_4x8_iadd', a, b, ('iadd', c, d))),
(('iadd', ('sdot_2x16_iadd(is_used_once)', a, b, '#c'), '#d'), ('sdot_2x16_iadd', a, b, ('iadd', c, d))),
(('iadd', ('udot_2x16_uadd(is_used_once)', a, b, '#c'), '#d'), ('udot_2x16_uadd', a, b, ('iadd', c, d))),
# Try to let constant folding eliminate the dot-product part. These are
# safe because the dot product cannot overflow 32 bits.
@ -215,12 +221,18 @@ optimizations = [
(('iadd', ('udot_4x8_uadd', 'a(is_not_const)', b, 0), c), ('udot_4x8_uadd', a, b, c)),
(('iadd', ('sudot_4x8_iadd', 'a(is_not_const)', b, 0), c), ('sudot_4x8_iadd', a, b, c)),
(('iadd', ('sudot_4x8_iadd', a, 'b(is_not_const)', 0), c), ('sudot_4x8_iadd', a, b, c)),
(('iadd', ('sdot_2x16_iadd', 'a(is_not_const)', b, 0), c), ('sdot_2x16_iadd', a, b, c)),
(('iadd', ('udot_2x16_uadd', 'a(is_not_const)', b, 0), c), ('udot_2x16_uadd', a, b, c)),
(('sdot_4x8_iadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('sdot_4x8_iadd', a, b, 0), c)),
(('udot_4x8_uadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('udot_4x8_uadd', a, b, 0), c)),
(('sudot_4x8_iadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('sudot_4x8_iadd', a, b, 0), c)),
(('sdot_2x16_iadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('sdot_2x16_iadd', a, b, 0), c)),
(('udot_2x16_uadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('udot_2x16_uadd', a, b, 0), c)),
(('sdot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c), '!options->lower_iadd_sat'),
(('udot_4x8_uadd_sat', '#a', '#b', 'c(is_not_const)'), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c), '!options->lower_uadd_sat'),
(('sudot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sudot_4x8_iadd', a, b, 0), c), '!options->lower_iadd_sat'),
(('sdot_2x16_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sdot_2x16_iadd', a, b, 0), c), '!options->lower_iadd_sat'),
(('udot_2x16_uadd_sat', '#a', '#b', 'c(is_not_const)'), ('uadd_sat', ('udot_2x16_uadd', a, b, 0), c), '!options->lower_uadd_sat'),
]
# Shorthand for the expansion of just the dot product part of the [iu]dp4a
@ -237,11 +249,17 @@ sudot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_u8',
('imul', ('extract_i8', a, 1), ('extract_u8', b, 1))),
('iadd', ('imul', ('extract_i8', a, 2), ('extract_u8', b, 2)),
('imul', ('extract_i8', a, 3), ('extract_u8', b, 3))))
sdot_2x16_a_b = ('iadd', ('imul', ('extract_i16', a, 0), ('extract_i16', b, 0)),
('imul', ('extract_i16', a, 1), ('extract_i16', b, 1)))
udot_2x16_a_b = ('iadd', ('imul', ('extract_u16', a, 0), ('extract_u16', b, 0)),
('imul', ('extract_u16', a, 1), ('extract_u16', b, 1)))
optimizations.extend([
(('sdot_4x8_iadd', a, b, c), ('iadd', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
(('udot_4x8_uadd', a, b, c), ('iadd', udot_4x8_a_b, c), '!options->has_dot_4x8'),
(('sudot_4x8_iadd', a, b, c), ('iadd', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
(('sdot_2x16_iadd', a, b, c), ('iadd', sdot_2x16_a_b, c), '!options->has_dot_2x16'),
(('udot_2x16_uadd', a, b, c), ('iadd', udot_2x16_a_b, c), '!options->has_dot_2x16'),
# For the unsigned dot-product, the largest possible value 4*(255*255) =
# 0x3f804, so we don't have to worry about that intermediate result
@ -257,6 +275,9 @@ optimizations.extend([
(('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', sdot_4x8_a_b, c), '!options->has_dot_4x8'),
(('sudot_4x8_iadd_sat', a, b, c), ('iadd_sat', sudot_4x8_a_b, c), '!options->has_sudot_4x8'),
(('udot_2x16_uadd_sat', a, b, c), ('uadd_sat', udot_2x16_a_b, c), '!options->has_dot_2x16'),
(('sdot_2x16_iadd_sat', a, b, c), ('iadd_sat', sdot_2x16_a_b, c), '!options->has_dot_2x16'),
])
# Float sizes