mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-17 16:08:06 +02:00
microsoft/compiler: Enable packed dot product intrinsics for SM6.4+
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22952>
This commit is contained in:
parent
217bbdc4fd
commit
a6ea08c542
4 changed files with 41 additions and 5 deletions
|
|
@ -110,6 +110,7 @@ static struct predefined_func_descr predefined_funcs[] = {
|
|||
{"dx.op.wavePrefixOp", "O", "iOcc", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
|
||||
{"dx.op.dot4AddPacked", "i", "iiii", DXIL_ATTR_KIND_READ_NONE},
|
||||
};
|
||||
|
||||
struct func_descr {
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ extern "C" {
|
|||
|
||||
bool dxil_nir_lower_8bit_conv(nir_shader *shader);
|
||||
bool dxil_nir_lower_16bit_conv(nir_shader *shader);
|
||||
bool dxil_nir_lower_x2b(nir_shader *shader);
|
||||
bool dxil_nir_algebraic(nir_shader *shader);
|
||||
bool dxil_nir_lower_fquantize2f16(nir_shader *shader);
|
||||
bool dxil_nir_lower_ubo_to_temp(nir_shader *shader);
|
||||
struct dxil_nir_lower_loads_stores_options {
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ import sys
|
|||
import math
|
||||
|
||||
a = 'a'
|
||||
b = 'b'
|
||||
c = 'c'
|
||||
|
||||
# The nir_lower_bit_size() pass gets rid of all 8bit ALUs but insert new u2u8
|
||||
# and i2i8 operations to convert the result back to the original type after the
|
||||
|
|
@ -91,9 +93,13 @@ def remove_unsupported_casts(arr, bit_size, mask, max_unsigned_float, min_signed
|
|||
remove_unsupported_casts(no_8bit_conv, 8, 0xff, 255.0, -128.0, 127.0)
|
||||
remove_unsupported_casts(no_16bit_conv, 16, 0xffff, 65535.0, -32768.0, 32767.0)
|
||||
|
||||
lower_x2b = [
|
||||
algebraic_ops = [
|
||||
(('b2b32', 'a'), ('b2i32', 'a')),
|
||||
(('b2b1', 'a'), ('ine', ('b2i32', a), 0)),
|
||||
|
||||
# We don't support the saturating versions of these
|
||||
(('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c)),
|
||||
(('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c)),
|
||||
]
|
||||
|
||||
no_16bit_conv += [
|
||||
|
|
@ -118,8 +124,8 @@ def run():
|
|||
no_8bit_conv).render())
|
||||
print(nir_algebraic.AlgebraicPass("dxil_nir_lower_16bit_conv",
|
||||
no_16bit_conv).render())
|
||||
print(nir_algebraic.AlgebraicPass("dxil_nir_lower_x2b",
|
||||
lower_x2b).render())
|
||||
print(nir_algebraic.AlgebraicPass("dxil_nir_algebraic",
|
||||
algebraic_ops).render())
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -179,6 +179,10 @@ dxil_get_nir_compiler_options(nir_shader_compiler_options *options,
|
|||
options->lower_doubles_options = ~0;
|
||||
if ((supported_int_sizes & 16) && (supported_float_sizes & 16))
|
||||
options->support_16bit_alu = true;
|
||||
if (shader_model_max >= SHADER_MODEL_6_4) {
|
||||
options->has_sdot_4x8 = true;
|
||||
options->has_udot_4x8 = true;
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
|
|
@ -373,6 +377,9 @@ enum dxil_intr {
|
|||
DXIL_INTR_RAW_BUFFER_LOAD = 139,
|
||||
DXIL_INTR_RAW_BUFFER_STORE = 140,
|
||||
|
||||
DXIL_INTR_DOT4_ADD_I8_PACKED = 163,
|
||||
DXIL_INTR_DOT4_ADD_U8_PACKED = 164,
|
||||
|
||||
DXIL_INTR_ANNOTATE_HANDLE = 216,
|
||||
DXIL_INTR_CREATE_HANDLE_FROM_BINDING = 217,
|
||||
DXIL_INTR_CREATE_HANDLE_FROM_HEAP = 218,
|
||||
|
|
@ -2499,6 +2506,25 @@ emit_bitfield_insert(struct ntd_context *ctx, nir_alu_instr *alu,
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
emit_dot4add_packed(struct ntd_context *ctx, nir_alu_instr *alu,
|
||||
enum dxil_intr intr,
|
||||
const struct dxil_value *src0,
|
||||
const struct dxil_value *src1,
|
||||
const struct dxil_value *accum)
|
||||
{
|
||||
const struct dxil_func *f = dxil_get_function(&ctx->mod, "dx.op.dot4AddPacked", DXIL_I32);
|
||||
if (!f)
|
||||
return false;
|
||||
const struct dxil_value *srcs[] = { dxil_module_get_int32_const(&ctx->mod, intr), accum, src0, src1 };
|
||||
const struct dxil_value *v = dxil_emit_call(&ctx->mod, f, srcs, ARRAY_SIZE(srcs));
|
||||
if (!v)
|
||||
return false;
|
||||
|
||||
store_alu_dest(ctx, alu, 0, v);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
|
||||
const struct dxil_value *sel,
|
||||
const struct dxil_value *val_true,
|
||||
|
|
@ -2868,6 +2894,9 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
|
|||
case nir_op_unpack_half_2x16_split_y: return emit_f16tof32(ctx, alu, src[0], true);
|
||||
case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0], src[1]);
|
||||
|
||||
case nir_op_sdot_4x8_iadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_I8_PACKED, src[0], src[1], src[2]);
|
||||
case nir_op_udot_4x8_uadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_U8_PACKED, src[0], src[1], src[2]);
|
||||
|
||||
case nir_op_i2i1:
|
||||
case nir_op_u2u1:
|
||||
case nir_op_b2i16:
|
||||
|
|
@ -6454,7 +6483,7 @@ optimize_nir(struct nir_shader *s, const struct nir_to_dxil_options *opts)
|
|||
NIR_PASS(progress, s, nir_opt_cse);
|
||||
NIR_PASS(progress, s, nir_opt_peephole_select, 8, true, true);
|
||||
NIR_PASS(progress, s, nir_opt_algebraic);
|
||||
NIR_PASS(progress, s, dxil_nir_lower_x2b);
|
||||
NIR_PASS(progress, s, dxil_nir_algebraic);
|
||||
if (s->options->lower_int64_options)
|
||||
NIR_PASS(progress, s, nir_lower_int64);
|
||||
NIR_PASS(progress, s, nir_lower_alu);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue