microsoft/compiler: Enable packed dot product intrinsics for SM6.4+

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22952>
This commit is contained in:
Jesse Natalie 2023-05-10 19:11:59 -07:00 committed by Marge Bot
parent 217bbdc4fd
commit a6ea08c542
4 changed files with 41 additions and 5 deletions

View file

@ -110,6 +110,7 @@ static struct predefined_func_descr predefined_funcs[] = {
{"dx.op.wavePrefixOp", "O", "iOcc", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.dot4AddPacked", "i", "iiii", DXIL_ATTR_KIND_READ_NONE},
};
struct func_descr {

View file

@ -34,7 +34,7 @@ extern "C" {
bool dxil_nir_lower_8bit_conv(nir_shader *shader);
bool dxil_nir_lower_16bit_conv(nir_shader *shader);
bool dxil_nir_lower_x2b(nir_shader *shader);
bool dxil_nir_algebraic(nir_shader *shader);
bool dxil_nir_lower_fquantize2f16(nir_shader *shader);
bool dxil_nir_lower_ubo_to_temp(nir_shader *shader);
struct dxil_nir_lower_loads_stores_options {

View file

@ -29,6 +29,8 @@ import sys
import math
a = 'a'
b = 'b'
c = 'c'
# The nir_lower_bit_size() pass gets rid of all 8bit ALUs but insert new u2u8
# and i2i8 operations to convert the result back to the original type after the
@ -91,9 +93,13 @@ def remove_unsupported_casts(arr, bit_size, mask, max_unsigned_float, min_signed
remove_unsupported_casts(no_8bit_conv, 8, 0xff, 255.0, -128.0, 127.0)
remove_unsupported_casts(no_16bit_conv, 16, 0xffff, 65535.0, -32768.0, 32767.0)
lower_x2b = [
algebraic_ops = [
(('b2b32', 'a'), ('b2i32', 'a')),
(('b2b1', 'a'), ('ine', ('b2i32', a), 0)),
# We don't support the saturating versions of these
(('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c)),
(('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c)),
]
no_16bit_conv += [
@ -118,8 +124,8 @@ def run():
no_8bit_conv).render())
print(nir_algebraic.AlgebraicPass("dxil_nir_lower_16bit_conv",
no_16bit_conv).render())
print(nir_algebraic.AlgebraicPass("dxil_nir_lower_x2b",
lower_x2b).render())
print(nir_algebraic.AlgebraicPass("dxil_nir_algebraic",
algebraic_ops).render())
if __name__ == '__main__':
main()

View file

@ -179,6 +179,10 @@ dxil_get_nir_compiler_options(nir_shader_compiler_options *options,
options->lower_doubles_options = ~0;
if ((supported_int_sizes & 16) && (supported_float_sizes & 16))
options->support_16bit_alu = true;
if (shader_model_max >= SHADER_MODEL_6_4) {
options->has_sdot_4x8 = true;
options->has_udot_4x8 = true;
}
}
static bool
@ -373,6 +377,9 @@ enum dxil_intr {
DXIL_INTR_RAW_BUFFER_LOAD = 139,
DXIL_INTR_RAW_BUFFER_STORE = 140,
DXIL_INTR_DOT4_ADD_I8_PACKED = 163,
DXIL_INTR_DOT4_ADD_U8_PACKED = 164,
DXIL_INTR_ANNOTATE_HANDLE = 216,
DXIL_INTR_CREATE_HANDLE_FROM_BINDING = 217,
DXIL_INTR_CREATE_HANDLE_FROM_HEAP = 218,
@ -2499,6 +2506,25 @@ emit_bitfield_insert(struct ntd_context *ctx, nir_alu_instr *alu,
return true;
}
static bool
emit_dot4add_packed(struct ntd_context *ctx, nir_alu_instr *alu,
enum dxil_intr intr,
const struct dxil_value *src0,
const struct dxil_value *src1,
const struct dxil_value *accum)
{
const struct dxil_func *f = dxil_get_function(&ctx->mod, "dx.op.dot4AddPacked", DXIL_I32);
if (!f)
return false;
const struct dxil_value *srcs[] = { dxil_module_get_int32_const(&ctx->mod, intr), accum, src0, src1 };
const struct dxil_value *v = dxil_emit_call(&ctx->mod, f, srcs, ARRAY_SIZE(srcs));
if (!v)
return false;
store_alu_dest(ctx, alu, 0, v);
return true;
}
static bool emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
const struct dxil_value *sel,
const struct dxil_value *val_true,
@ -2868,6 +2894,9 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
case nir_op_unpack_half_2x16_split_y: return emit_f16tof32(ctx, alu, src[0], true);
case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0], src[1]);
case nir_op_sdot_4x8_iadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_I8_PACKED, src[0], src[1], src[2]);
case nir_op_udot_4x8_uadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_U8_PACKED, src[0], src[1], src[2]);
case nir_op_i2i1:
case nir_op_u2u1:
case nir_op_b2i16:
@ -6454,7 +6483,7 @@ optimize_nir(struct nir_shader *s, const struct nir_to_dxil_options *opts)
NIR_PASS(progress, s, nir_opt_cse);
NIR_PASS(progress, s, nir_opt_peephole_select, 8, true, true);
NIR_PASS(progress, s, nir_opt_algebraic);
NIR_PASS(progress, s, dxil_nir_lower_x2b);
NIR_PASS(progress, s, dxil_nir_algebraic);
if (s->options->lower_int64_options)
NIR_PASS(progress, s, nir_lower_int64);
NIR_PASS(progress, s, nir_lower_alu);