microsoft/compiler: Add some more float16 support

We can support float16 constants, b2f16, and casts to float16.

Reviewed-by: Enrico Galli <enrico.galli@intel.com>
Reviewed-by: Michael Tang <tangm@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10063>
This commit is contained in:
Jesse Natalie 2021-04-06 10:48:26 -07:00 committed by Marge Bot
parent ca08e74525
commit 34c84b6f0e
3 changed files with 59 additions and 0 deletions

View file

@ -1623,6 +1623,30 @@ dxil_module_get_int_const(struct dxil_module *m, intmax_t value,
}
}
const struct dxil_value *
dxil_module_get_float16_const(struct dxil_module *m, uint16_t value)
{
const struct dxil_type *type = get_float16_type(m);
if (!type)
return NULL;
struct dxil_const *c;
LIST_FOR_EACH_ENTRY(c, &m->const_list, head) {
if (c->value.type != type || c->undef)
continue;
if (c->int_value == (uintmax_t)value)
return &c->value;
}
c = create_const(m, type, false);
if (!c)
return NULL;
c->int_value = (uintmax_t)value;
return &c->value;
}
const struct dxil_value *
dxil_module_get_float_const(struct dxil_module *m, float value)
{
@ -2025,6 +2049,15 @@ emit_int_value(struct dxil_module *m, int64_t value)
data, ARRAY_SIZE(data));
}
static bool
emit_float16_value(struct dxil_module *m, uint16_t value)
{
if (!value)
return emit_null_value(m);
uint64_t data = value;
return emit_record_no_abbrev(&m->buf, CST_CODE_FLOAT, &data, 1);
}
static bool
emit_float_value(struct dxil_module *m, float value)
{
@ -2087,6 +2120,10 @@ emit_consts(struct dxil_module *m)
case TYPE_FLOAT:
switch (curr_type->float_bits) {
case 16:
if (!emit_float16_value(m, (uint16_t)(uintmax_t)c->int_value))
return false;
break;
case 32:
if (!emit_float_value(m, c->float_value))
return false;

View file

@ -330,6 +330,9 @@ const struct dxil_value *
dxil_module_get_int_const(struct dxil_module *m, intmax_t value,
unsigned bit_size);
const struct dxil_value *
dxil_module_get_float16_const(struct dxil_module *m, uint16_t);
const struct dxil_value *
dxil_module_get_float_const(struct dxil_module *m, float value);

View file

@ -1555,11 +1555,13 @@ get_cast_op(nir_alu_instr *alu)
return DXIL_CAST_FPTOUI;
/* int -> float */
case nir_op_i2f16:
case nir_op_i2f32:
case nir_op_i2f64:
return DXIL_CAST_SITOFP;
/* uint -> float */
case nir_op_u2f16:
case nir_op_u2f32:
case nir_op_u2f64:
return DXIL_CAST_UITOFP;
@ -1735,6 +1737,22 @@ static bool emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
return true;
}
static bool
emit_b2f16(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val)
{
assert(val);
struct dxil_module *m = &ctx->mod;
const struct dxil_value *c1 = dxil_module_get_float16_const(m, 0x3C00);
const struct dxil_value *c0 = dxil_module_get_float16_const(m, 0);
if (!c0 || !c1)
return false;
return emit_select(ctx, alu, val, c1, c0);
}
static bool
emit_b2f32(struct ntd_context *ctx, nir_alu_instr *alu, const struct dxil_value *val)
{
@ -2056,6 +2074,7 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
return emit_cast(ctx, alu, src[0]);
case nir_op_f2b32: return emit_f2b32(ctx, alu, src[0]);
case nir_op_b2f16: return emit_b2f16(ctx, alu, src[0]);
case nir_op_b2f32: return emit_b2f32(ctx, alu, src[0]);
default:
NIR_INSTR_UNSUPPORTED(&alu->instr);