microsoft/compiler: Handle mediump

Instead of treating all 16-bit values as "native 16-bit types,"
differentiate between concrete casts and mediump casts, where the
former requires native 16-bit types, and the latter only requires
DXIL min-precision. Additionally, UBO/SSBO loads/stores require
native 16-bit types.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23344>
This commit is contained in:
Jesse Natalie 2023-05-31 13:14:58 -07:00 committed by Marge Bot
parent 7371c9a2a8
commit ea68135ed1

View file

@ -156,6 +156,7 @@ nir_options = {
.lower_device_index_to_zero = true,
.linker_ignore_precision = true,
.support_16bit_alu = true,
.preserve_mediump = true,
};
const nir_shader_compiler_options*
@ -2139,7 +2140,7 @@ store_dest(struct ntd_context *ctx, nir_dest *dest, unsigned chan,
ctx->mod.feats.doubles = true;
if (type == ctx->mod.float16_type ||
type == ctx->mod.int16_type)
ctx->mod.feats.native_low_precision = true;
ctx->mod.feats.min_precision = true;
if (type == ctx->mod.int64_type)
ctx->mod.feats.int64_ops = true;
store_dest_value(ctx, dest, chan, value);
@ -2309,6 +2310,7 @@ get_cast_op(nir_alu_instr *alu)
/* float -> float */
case nir_op_f2f16_rtz:
case nir_op_f2f16:
case nir_op_f2fmp:
case nir_op_f2f32:
case nir_op_f2f64:
assert(dst_bits != src_bits);
@ -2320,6 +2322,7 @@ get_cast_op(nir_alu_instr *alu)
/* int -> int */
case nir_op_i2i1:
case nir_op_i2i16:
case nir_op_i2imp:
case nir_op_i2i32:
case nir_op_i2i64:
assert(dst_bits != src_bits);
@ -2341,24 +2344,28 @@ get_cast_op(nir_alu_instr *alu)
/* float -> int */
case nir_op_f2i16:
case nir_op_f2imp:
case nir_op_f2i32:
case nir_op_f2i64:
return DXIL_CAST_FPTOSI;
/* float -> uint */
case nir_op_f2u16:
case nir_op_f2ump:
case nir_op_f2u32:
case nir_op_f2u64:
return DXIL_CAST_FPTOUI;
/* int -> float */
case nir_op_i2f16:
case nir_op_i2fmp:
case nir_op_i2f32:
case nir_op_i2f64:
return DXIL_CAST_SITOFP;
/* uint -> float */
case nir_op_u2f16:
case nir_op_u2fmp:
case nir_op_u2f32:
case nir_op_u2f64:
return DXIL_CAST_UITOFP;
@ -2420,6 +2427,20 @@ emit_cast(struct ntd_context *ctx, nir_alu_instr *alu,
break;
}
if (nir_dest_bit_size(alu->dest.dest) == 16) {
switch (alu->op) {
case nir_op_f2fmp:
case nir_op_i2imp:
case nir_op_f2imp:
case nir_op_f2ump:
case nir_op_i2fmp:
case nir_op_u2fmp:
break;
default:
ctx->mod.feats.native_low_precision = true;
}
}
const struct dxil_value *v = dxil_emit_cast(&ctx->mod, opcode, type,
value);
if (!v)
@ -2960,13 +2981,19 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
case nir_op_u2u1:
case nir_op_b2i16:
case nir_op_i2i16:
case nir_op_i2imp:
case nir_op_f2i16:
case nir_op_f2imp:
case nir_op_f2u16:
case nir_op_f2ump:
case nir_op_u2u16:
case nir_op_u2f16:
case nir_op_u2fmp:
case nir_op_i2f16:
case nir_op_i2fmp:
case nir_op_f2f16_rtz:
case nir_op_f2f16:
case nir_op_f2fmp:
case nir_op_b2i32:
case nir_op_f2f32:
case nir_op_f2i32:
@ -3428,6 +3455,8 @@ emit_load_ssbo(struct ntd_context *ctx, nir_intrinsic_instr *intr)
return false;
store_dest(ctx, &intr->dest, i, val);
}
if (nir_dest_bit_size(intr->dest) == 16)
ctx->mod.feats.native_low_precision = true;
return true;
}
@ -3442,6 +3471,8 @@ emit_store_ssbo(struct ntd_context *ctx, nir_intrinsic_instr *intr)
unsigned num_components = nir_src_num_components(intr->src[0]);
assert(num_components <= 4);
if (nir_src_bit_size(intr->src[0]) == 16)
ctx->mod.feats.native_low_precision = true;
nir_alu_type type =
dxil_type_to_nir_type(dxil_value_get_type(get_src_ssa(ctx, intr->src[0].ssa, 0)));
@ -3623,6 +3654,8 @@ emit_load_ubo(struct ntd_context *ctx, nir_intrinsic_instr *intr)
const struct dxil_value *retval = dxil_emit_extractval(&ctx->mod, agg, i);
store_dest(ctx, &intr->dest, i, retval);
}
if (nir_dest_bit_size(intr->dest) == 16)
ctx->mod.feats.native_low_precision = true;
return true;
}
@ -3648,6 +3681,8 @@ emit_load_ubo_dxil(struct ntd_context *ctx, nir_intrinsic_instr *intr)
store_dest(ctx, &intr->dest, i,
dxil_emit_extractval(&ctx->mod, agg, i));
if (nir_dest_bit_size(intr->dest) == 16)
ctx->mod.feats.native_low_precision = true;
return true;
}
@ -5143,7 +5178,7 @@ get_value_for_const(struct dxil_module *mod, nir_const_value *c, const struct dx
if (type == mod->float32_type) return dxil_module_get_float_const(mod, c->f32);
if (type == mod->int32_type) return dxil_module_get_int32_const(mod, c->i32);
if (type == mod->int16_type) {
mod->feats.native_low_precision = true;
mod->feats.min_precision = true;
return dxil_module_get_int16_const(mod, c->i16);
}
if (type == mod->int64_type) {
@ -5151,7 +5186,7 @@ get_value_for_const(struct dxil_module *mod, nir_const_value *c, const struct dx
return dxil_module_get_int64_const(mod, c->i64);
}
if (type == mod->float16_type) {
mod->feats.native_low_precision = true;
mod->feats.min_precision = true;
return dxil_module_get_float16_const(mod, c->u16);
}
if (type == mod->float64_type) {
@ -6750,6 +6785,9 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
struct dxil_container container;
dxil_container_init(&container);
/* Native low precision disables min-precision */
if (ctx->mod.feats.native_low_precision)
ctx->mod.feats.min_precision = false;
if (!dxil_container_add_features(&container, &ctx->mod.feats)) {
debug_printf("D3D12: dxil_container_add_features failed\n");
retval = false;