diff --git a/src/microsoft/compiler/dxil_function.c b/src/microsoft/compiler/dxil_function.c
index f4704adf559..f72ef1ff745 100644
--- a/src/microsoft/compiler/dxil_function.c
+++ b/src/microsoft/compiler/dxil_function.c
@@ -110,6 +110,7 @@ static struct predefined_func_descr predefined_funcs[] = {
    {"dx.op.wavePrefixOp", "O", "iOcc", DXIL_ATTR_KIND_NO_UNWIND},
    {"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
    {"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
+   {"dx.op.dot4AddPacked", "i", "iiii", DXIL_ATTR_KIND_READ_NONE},
 };
 
 struct func_descr {
diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h
index ff404af43aa..9c9d3408588 100644
--- a/src/microsoft/compiler/dxil_nir.h
+++ b/src/microsoft/compiler/dxil_nir.h
@@ -34,7 +34,7 @@ extern "C" {
 
 bool dxil_nir_lower_8bit_conv(nir_shader *shader);
 bool dxil_nir_lower_16bit_conv(nir_shader *shader);
-bool dxil_nir_lower_x2b(nir_shader *shader);
+bool dxil_nir_algebraic(nir_shader *shader);
 bool dxil_nir_lower_fquantize2f16(nir_shader *shader);
 bool dxil_nir_lower_ubo_to_temp(nir_shader *shader);
 struct dxil_nir_lower_loads_stores_options {
diff --git a/src/microsoft/compiler/dxil_nir_algebraic.py b/src/microsoft/compiler/dxil_nir_algebraic.py
index 9fd9ca54f21..868f799a0f7 100644
--- a/src/microsoft/compiler/dxil_nir_algebraic.py
+++ b/src/microsoft/compiler/dxil_nir_algebraic.py
@@ -29,6 +29,8 @@ import sys
 import math
 
 a = 'a'
+b = 'b'
+c = 'c'
 
 # The nir_lower_bit_size() pass gets rid of all 8bit ALUs but insert new u2u8
 # and i2i8 operations to convert the result back to the original type after the
@@ -91,9 +93,13 @@ def remove_unsupported_casts(arr, bit_size, mask, max_unsigned_float, min_signed
 remove_unsupported_casts(no_8bit_conv, 8, 0xff, 255.0, -128.0, 127.0)
 remove_unsupported_casts(no_16bit_conv, 16, 0xffff, 65535.0, -32768.0, 32767.0)
 
-lower_x2b = [
+algebraic_ops = [
   (('b2b32', 'a'), ('b2i32', 'a')),
   (('b2b1', 'a'), ('ine', ('b2i32', a), 0)),
+
+  # We don't support the saturating versions of these
+  (('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c)),
+  (('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c)),
 ]
 
 no_16bit_conv += [
@@ -118,8 +124,8 @@ def run():
                                       no_8bit_conv).render())
     print(nir_algebraic.AlgebraicPass("dxil_nir_lower_16bit_conv",
                                       no_16bit_conv).render())
-    print(nir_algebraic.AlgebraicPass("dxil_nir_lower_x2b",
-                                      lower_x2b).render())
+    print(nir_algebraic.AlgebraicPass("dxil_nir_algebraic",
+                                      algebraic_ops).render())
 
 if __name__ == '__main__':
     main()
diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c
index 8b8e718c7b8..4f902d27606 100644
--- a/src/microsoft/compiler/nir_to_dxil.c
+++ b/src/microsoft/compiler/nir_to_dxil.c
@@ -179,6 +179,10 @@ dxil_get_nir_compiler_options(nir_shader_compiler_options *options,
       options->lower_doubles_options = ~0;
    if ((supported_int_sizes & 16) && (supported_float_sizes & 16))
       options->support_16bit_alu = true;
+   if (shader_model_max >= SHADER_MODEL_6_4) {
+      options->has_sdot_4x8 = true;
+      options->has_udot_4x8 = true;
+   }
 }
 
 static bool
@@ -373,6 +377,9 @@ enum dxil_intr {
    DXIL_INTR_RAW_BUFFER_LOAD = 139,
    DXIL_INTR_RAW_BUFFER_STORE = 140,
 
+   DXIL_INTR_DOT4_ADD_I8_PACKED = 163,
+   DXIL_INTR_DOT4_ADD_U8_PACKED = 164,
+
    DXIL_INTR_ANNOTATE_HANDLE = 216,
    DXIL_INTR_CREATE_HANDLE_FROM_BINDING = 217,
    DXIL_INTR_CREATE_HANDLE_FROM_HEAP = 218,
@@ -2499,6 +2506,25 @@ emit_bitfield_insert(struct ntd_context *ctx, nir_alu_instr *alu,
    return true;
 }
 
+static bool
+emit_dot4add_packed(struct ntd_context *ctx, nir_alu_instr *alu,
+                    enum dxil_intr intr,
+                    const struct dxil_value *src0,
+                    const struct dxil_value *src1,
+                    const struct dxil_value *accum)
+{
+   const struct dxil_func *f = dxil_get_function(&ctx->mod, "dx.op.dot4AddPacked", DXIL_I32);
+   if (!f)
+      return false;
+   const struct dxil_value *srcs[] = { dxil_module_get_int32_const(&ctx->mod, intr), accum, src0, src1 };
+   const struct dxil_value *v = dxil_emit_call(&ctx->mod, f, srcs, ARRAY_SIZE(srcs));
+   if (!v)
+      return false;
+
+   store_alu_dest(ctx, alu, 0, v);
+   return true;
+}
+
 static bool
 emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
             const struct dxil_value *sel, const struct dxil_value *val_true,
@@ -2868,6 +2894,9 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
    case nir_op_unpack_half_2x16_split_y: return emit_f16tof32(ctx, alu, src[0], true);
    case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0], src[1]);
 
+   case nir_op_sdot_4x8_iadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_I8_PACKED, src[0], src[1], src[2]);
+   case nir_op_udot_4x8_uadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_U8_PACKED, src[0], src[1], src[2]);
+
    case nir_op_i2i1:
    case nir_op_u2u1:
    case nir_op_b2i16:
@@ -6454,7 +6483,7 @@ optimize_nir(struct nir_shader *s, const struct nir_to_dxil_options *opts)
       NIR_PASS(progress, s, nir_opt_cse);
       NIR_PASS(progress, s, nir_opt_peephole_select, 8, true, true);
       NIR_PASS(progress, s, nir_opt_algebraic);
-      NIR_PASS(progress, s, dxil_nir_lower_x2b);
+      NIR_PASS(progress, s, dxil_nir_algebraic);
       if (s->options->lower_int64_options)
          NIR_PASS(progress, s, nir_lower_int64);
       NIR_PASS(progress, s, nir_lower_alu);