diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 0ee29855fd5..04ca38fb014 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -3688,6 +3688,12 @@ typedef struct nir_shader_compiler_options { * for rect texture lowering. */ bool has_txs; + /** Backend supports sdot_4x8 and udot_4x8 opcodes. */ + bool has_dot_4x8; + + /** Backend supports sudot_4x8 opcodes. */ + bool has_sudot_4x8; + /* Whether to generate only scoped_barrier intrinsics instead of the set of * memory and control barrier intrinsics based on GLSL. */ diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index 6b2fc24300a..527cf6bb56a 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -1314,3 +1314,110 @@ unop_horiz("pack_double_2x32_dxil", 1, tuint64, 2, tuint32, "dst.x = src0.x | ((uint64_t)src0.y << 32);") unop_horiz("unpack_double_2x32_dxil", 2, tuint32, 1, tuint64, "dst.x = src0.x; dst.y = src0.x >> 32;") + +# src0 and src1 are i8vec4 packed in an int32, and src2 is an int32. The int8 +# components are sign-extended to 32-bits, and a dot-product is performed on +# the resulting vectors. src2 is added to the result of the dot-product. +opcode("sdot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, _2src_commutative, """ + const int32_t v0x = (int8_t)(src0 ); + const int32_t v0y = (int8_t)(src0 >> 8); + const int32_t v0z = (int8_t)(src0 >> 16); + const int32_t v0w = (int8_t)(src0 >> 24); + const int32_t v1x = (int8_t)(src1 ); + const int32_t v1y = (int8_t)(src1 >> 8); + const int32_t v1z = (int8_t)(src1 >> 16); + const int32_t v1w = (int8_t)(src1 >> 24); + + dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + (uint32_t)src2; /* unsigned add wraps; avoids signed-overflow UB */ +""") + +# Like sdot_4x8_iadd, but unsigned.
+opcode("udot_4x8_uadd", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32], + False, _2src_commutative, """ + const uint32_t v0x = (uint8_t)(src0 ); + const uint32_t v0y = (uint8_t)(src0 >> 8); + const uint32_t v0z = (uint8_t)(src0 >> 16); + const uint32_t v0w = (uint8_t)(src0 >> 24); + const uint32_t v1x = (uint8_t)(src1 ); + const uint32_t v1y = (uint8_t)(src1 >> 8); + const uint32_t v1z = (uint8_t)(src1 >> 16); + const uint32_t v1w = (uint8_t)(src1 >> 24); + + dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2; +""") + +# src0 is i8vec4 packed in an int32, src1 is u8vec4 packed in an int32, and +# src2 is an int32. src0's components are sign-extended and src1's are +# zero-extended to 32-bits, and a dot-product is performed on the resulting +# vectors. src2 is added to the result of the dot-product. +# +# NOTE: Unlike many of the other dp4a opcodes, the mixed signs of source 0 +# and source 1 mean that this opcode is not 2-source commutative. +opcode("sudot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, "", """ + const int32_t v0x = (int8_t)(src0 ); + const int32_t v0y = (int8_t)(src0 >> 8); + const int32_t v0z = (int8_t)(src0 >> 16); + const int32_t v0w = (int8_t)(src0 >> 24); + const uint32_t v1x = (uint8_t)(src1 ); + const uint32_t v1y = (uint8_t)(src1 >> 8); + const uint32_t v1z = (uint8_t)(src1 >> 16); + const uint32_t v1w = (uint8_t)(src1 >> 24); + + dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2; +""") + +# Like sdot_4x8_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
+opcode("sdot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, _2src_commutative, """ + const int64_t v0x = (int8_t)(src0 ); + const int64_t v0y = (int8_t)(src0 >> 8); + const int64_t v0z = (int8_t)(src0 >> 16); + const int64_t v0w = (int8_t)(src0 >> 24); + const int64_t v1x = (int8_t)(src1 ); + const int64_t v1y = (int8_t)(src1 >> 8); + const int64_t v1z = (int8_t)(src1 >> 16); + const int64_t v1w = (int8_t)(src1 >> 24); + + const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2; + + dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp); +""") + +# Like udot_4x8_uadd, but the result is clamped to the range [0, 0xffffffff]. +opcode("udot_4x8_uadd_sat", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32], + False, _2src_commutative, """ + const uint64_t v0x = (uint8_t)(src0 ); + const uint64_t v0y = (uint8_t)(src0 >> 8); + const uint64_t v0z = (uint8_t)(src0 >> 16); + const uint64_t v0w = (uint8_t)(src0 >> 24); + const uint64_t v1x = (uint8_t)(src1 ); + const uint64_t v1y = (uint8_t)(src1 >> 8); + const uint64_t v1z = (uint8_t)(src1 >> 16); + const uint64_t v1w = (uint8_t)(src1 >> 24); + + const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2; + + dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp; +""") + +# Like sudot_4x8_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
+# +# NOTE: Unlike many of the other dp4a opcodes, the mixed signs of source 0 +# and source 1 mean that this opcode is not 2-source commutative. +opcode("sudot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32], + False, "", """ + const int64_t v0x = (int8_t)(src0 ); + const int64_t v0y = (int8_t)(src0 >> 8); + const int64_t v0z = (int8_t)(src0 >> 16); + const int64_t v0w = (int8_t)(src0 >> 24); + const int64_t v1x = (uint8_t)(src1 ); /* zero-extended, but held signed so the math below stays in int64_t */ + const int64_t v1y = (uint8_t)(src1 >> 8); + const int64_t v1z = (uint8_t)(src1 >> 16); + const int64_t v1w = (uint8_t)(src1 >> 24); + + const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2; + + dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp); +""")