diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 1e3bb82ab11..a58c136399b 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -3728,6 +3728,9 @@ typedef struct nir_shader_compiler_options { /** Backend supports sudot_4x8 opcodes. */ bool has_sudot_4x8; + /** Backend supports sdot_2x16 and udot_2x16 opcodes. */ + bool has_dot_2x16; + /* Whether to generate only scoped_barrier intrinsics instead of the set of * memory and control barrier intrinsics based on GLSL. */ diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index af32edda067..95470b1255f 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -1426,3 +1426,53 @@ opcode("sudot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp); """) + +# src0 and src1 are i16vec2 packed in an int32, and src2 is an int32. The int16 +# components are sign-extended to 32-bits, and a dot-product is performed on +# the resulting vectors. src2 is added to the result of the dot-product. +opcode("sdot_2x16_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, _2src_commutative, """ + const int32_t v0x = (int16_t)(src0 ); + const int32_t v0y = (int16_t)(src0 >> 16); + const int32_t v1x = (int16_t)(src1 ); + const int32_t v1y = (int16_t)(src1 >> 16); + + dst = (v0x * v1x) + (v0y * v1y) + src2; +""") + +# Like sdot_2x16_iadd, but unsigned. +opcode("udot_2x16_uadd", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32], + False, _2src_commutative, """ + const uint32_t v0x = (uint16_t)(src0 ); + const uint32_t v0y = (uint16_t)(src0 >> 16); + const uint32_t v1x = (uint16_t)(src1 ); + const uint32_t v1y = (uint16_t)(src1 >> 16); + + dst = (v0x * v1x) + (v0y * v1y) + src2; +""") + +# Like sdot_2x16_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
+opcode("sdot_2x16_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, _2src_commutative, """ + const int64_t v0x = (int16_t)(src0 ); + const int64_t v0y = (int16_t)(src0 >> 16); + const int64_t v1x = (int16_t)(src1 ); + const int64_t v1y = (int16_t)(src1 >> 16); + + const int64_t tmp = (v0x * v1x) + (v0y * v1y) + src2; + + dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp); +""") + +# Like udot_2x16_uadd, but the result is clamped to the range [0, 0xffffffff]. +opcode("udot_2x16_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, _2src_commutative, """ + const uint64_t v0x = (uint16_t)(src0 ); + const uint64_t v0y = (uint16_t)(src0 >> 16); + const uint64_t v1x = (uint16_t)(src1 ); + const uint64_t v1y = (uint16_t)(src1 >> 16); + + const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + src2; + + dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp; +""") diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 1d9b1b2d1fb..fad49d2e5f1 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -197,6 +197,10 @@ optimizations = [ (('udot_4x8_uadd', a, 0, b), b), (('sdot_4x8_iadd_sat', a, 0, b), b), (('udot_4x8_uadd_sat', a, 0, b), b), + (('sdot_2x16_iadd', a, 0, b), b), + (('udot_2x16_uadd', a, 0, b), b), + (('sdot_2x16_iadd_sat', a, 0, b), b), + (('udot_2x16_uadd_sat', a, 0, b), b), # sudot_4x8_iadd is not commutative at all, so the patterns must be # duplicated with zeros on each of the first positions.
@@ -208,6 +212,8 @@ optimizations = [ (('iadd', ('sdot_4x8_iadd(is_used_once)', a, b, '#c'), '#d'), ('sdot_4x8_iadd', a, b, ('iadd', c, d))), (('iadd', ('udot_4x8_uadd(is_used_once)', a, b, '#c'), '#d'), ('udot_4x8_uadd', a, b, ('iadd', c, d))), (('iadd', ('sudot_4x8_iadd(is_used_once)', a, b, '#c'), '#d'), ('sudot_4x8_iadd', a, b, ('iadd', c, d))), + (('iadd', ('sdot_2x16_iadd(is_used_once)', a, b, '#c'), '#d'), ('sdot_2x16_iadd', a, b, ('iadd', c, d))), + (('iadd', ('udot_2x16_uadd(is_used_once)', a, b, '#c'), '#d'), ('udot_2x16_uadd', a, b, ('iadd', c, d))), # Try to let constant folding eliminate the dot-product part. These are # safe because the dot product cannot overflow 32 bits. @@ -215,12 +221,18 @@ optimizations = [ (('iadd', ('udot_4x8_uadd', 'a(is_not_const)', b, 0), c), ('udot_4x8_uadd', a, b, c)), (('iadd', ('sudot_4x8_iadd', 'a(is_not_const)', b, 0), c), ('sudot_4x8_iadd', a, b, c)), (('iadd', ('sudot_4x8_iadd', a, 'b(is_not_const)', 0), c), ('sudot_4x8_iadd', a, b, c)), + (('iadd', ('sdot_2x16_iadd', 'a(is_not_const)', b, 0), c), ('sdot_2x16_iadd', a, b, c)), + (('iadd', ('udot_2x16_uadd', 'a(is_not_const)', b, 0), c), ('udot_2x16_uadd', a, b, c)), (('sdot_4x8_iadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('sdot_4x8_iadd', a, b, 0), c)), (('udot_4x8_uadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('udot_4x8_uadd', a, b, 0), c)), (('sudot_4x8_iadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('sudot_4x8_iadd', a, b, 0), c)), + (('sdot_2x16_iadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('sdot_2x16_iadd', a, b, 0), c)), + (('udot_2x16_uadd', '#a', '#b', 'c(is_not_const)'), ('iadd', ('udot_2x16_uadd', a, b, 0), c)), (('sdot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c), '!options->lower_iadd_sat'), (('udot_4x8_uadd_sat', '#a', '#b', 'c(is_not_const)'), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c), '!options->lower_uadd_sat'), (('sudot_4x8_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', 
('sudot_4x8_iadd', a, b, 0), c), '!options->lower_iadd_sat'), + (('sdot_2x16_iadd_sat', '#a', '#b', 'c(is_not_const)'), ('iadd_sat', ('sdot_2x16_iadd', a, b, 0), c), '!options->lower_iadd_sat'), + (('udot_2x16_uadd_sat', '#a', '#b', 'c(is_not_const)'), ('uadd_sat', ('udot_2x16_uadd', a, b, 0), c), '!options->lower_uadd_sat'), ] # Shorthand for the expansion of just the dot product part of the [iu]dp4a @@ -237,11 +249,17 @@ sudot_4x8_a_b = ('iadd', ('iadd', ('imul', ('extract_i8', a, 0), ('extract_u8', ('imul', ('extract_i8', a, 1), ('extract_u8', b, 1))), ('iadd', ('imul', ('extract_i8', a, 2), ('extract_u8', b, 2)), ('imul', ('extract_i8', a, 3), ('extract_u8', b, 3)))) +sdot_2x16_a_b = ('iadd', ('imul', ('extract_i16', a, 0), ('extract_i16', b, 0)), + ('imul', ('extract_i16', a, 1), ('extract_i16', b, 1))) +udot_2x16_a_b = ('iadd', ('imul', ('extract_u16', a, 0), ('extract_u16', b, 0)), + ('imul', ('extract_u16', a, 1), ('extract_u16', b, 1))) optimizations.extend([ (('sdot_4x8_iadd', a, b, c), ('iadd', sdot_4x8_a_b, c), '!options->has_dot_4x8'), (('udot_4x8_uadd', a, b, c), ('iadd', udot_4x8_a_b, c), '!options->has_dot_4x8'), (('sudot_4x8_iadd', a, b, c), ('iadd', sudot_4x8_a_b, c), '!options->has_sudot_4x8'), + (('sdot_2x16_iadd', a, b, c), ('iadd', sdot_2x16_a_b, c), '!options->has_dot_2x16'), + (('udot_2x16_uadd', a, b, c), ('iadd', udot_2x16_a_b, c), '!options->has_dot_2x16'), # For the unsigned dot-product, the largest possible value 4*(255*255) = # 0x3f804, so we don't have to worry about that intermediate result @@ -257,6 +275,9 @@ optimizations.extend([ (('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', sdot_4x8_a_b, c), '!options->has_dot_4x8'), (('sudot_4x8_iadd_sat', a, b, c), ('iadd_sat', sudot_4x8_a_b, c), '!options->has_sudot_4x8'), + + (('udot_2x16_uadd_sat', a, b, c), ('uadd_sat', udot_2x16_a_b, c), '!options->has_dot_2x16'), + (('sdot_2x16_iadd_sat', a, b, c), ('iadd_sat', sdot_2x16_a_b, c), '!options->has_dot_2x16'), ]) # Float sizes