microsoft/compiler: Fix handling of fp16-in-32bit-val ops to handle high bits

Reviewed-by: Sil Vilerino <sivileri@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14342>
This commit is contained in:
Jesse Natalie 2021-12-30 17:50:51 -08:00 committed by Marge Bot
parent 20e374f4a3
commit 02fc28625f

View file

@ -104,6 +104,11 @@ nir_options = {
.lower_pack_32_2x16_split = true,
.lower_unpack_64_2x32_split = true,
.lower_unpack_32_2x16_split = true,
.lower_unpack_half_2x16 = true,
.lower_unpack_snorm_2x16 = true,
.lower_unpack_snorm_4x8 = true,
.lower_unpack_unorm_2x16 = true,
.lower_unpack_unorm_4x8 = true,
.has_fsub = true,
.has_isub = true,
.use_scoped_barrier = true,
@ -1953,8 +1958,15 @@ emit_ufind_msb(struct ntd_context *ctx, nir_alu_instr *alu,
}
static bool
emit_f16tof32(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val)
emit_f16tof32(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val, bool shift)
{
if (shift) {
val = dxil_emit_binop(&ctx->mod, DXIL_BINOP_LSHR, val,
dxil_module_get_int32_const(&ctx->mod, 16), 0);
if (!val)
return false;
}
const struct dxil_func *func = dxil_get_function(&ctx->mod,
"dx.op.legacyF16ToF32",
DXIL_NONE);
@ -1978,7 +1990,7 @@ emit_f16tof32(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_val
}
static bool
emit_f32tof16(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val)
emit_f32tof16(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val0, const struct dxil_value *val1)
{
const struct dxil_func *func = dxil_get_function(&ctx->mod,
"dx.op.legacyF32ToF16",
@ -1992,12 +2004,29 @@ emit_f32tof16(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_val
const struct dxil_value *args[] = {
opcode,
val
val0
};
const struct dxil_value *v = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!v)
return false;
if (!nir_src_is_const(alu->src[1].src) || nir_src_as_int(alu->src[1].src) != 0) {
args[1] = val1;
const struct dxil_value *v_high = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!v_high)
return false;
v_high = dxil_emit_binop(&ctx->mod, DXIL_BINOP_SHL, v_high,
dxil_module_get_int32_const(&ctx->mod, 16), 0);
if (!v_high)
return false;
v = dxil_emit_binop(&ctx->mod, DXIL_BINOP_OR, v, v_high, 0);
if (!v)
return false;
}
store_alu_dest(ctx, alu, 0, v);
return true;
}
@ -2192,8 +2221,9 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
case nir_op_fmin: return emit_binary_intin(ctx, alu, DXIL_INTR_FMIN, src[0], src[1]);
case nir_op_ffma: return emit_tertiary_intin(ctx, alu, DXIL_INTR_FMA, src[0], src[1], src[2]);
case nir_op_unpack_half_2x16_split_x: return emit_f16tof32(ctx, alu, src[0]);
case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0]);
case nir_op_unpack_half_2x16_split_x: return emit_f16tof32(ctx, alu, src[0], false);
case nir_op_unpack_half_2x16_split_y: return emit_f16tof32(ctx, alu, src[0], true);
case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0], src[1]);
case nir_op_b2i16:
case nir_op_i2i16: