spirv: construct a bfloat16 from the given SPIR-V bitsize and encoding

Signed-off-by: Rohan Garg <rohan.garg@intel.com>
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Rohan Garg 2024-09-11 12:11:51 +02:00 committed by Marge Bot
parent fb6ae2eac1
commit dc8074683d

View file

@ -1886,10 +1886,26 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
case SpvOpTypeFloat: {
int bit_size = w[2];
val->type->base_type = vtn_base_type_scalar;
vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64,
"Invalid float bit size: %u", bit_size);
val->type->type = glsl_floatN_t_type(bit_size);
val->type->length = 1;
int32_t encoding = count > 3 ? w[3] : -1;
switch (encoding) {
case -1:
/* No encoding specified, it is a regular FP. */
vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64,
"Invalid float bit size: %u", bit_size);
val->type->type = glsl_floatN_t_type(bit_size);
break;
case SpvFPEncodingBFloat16KHR:
vtn_fail_if(bit_size != 16,
"Invalid Bfloat16 bit size: %u", bit_size);
val->type->type = glsl_bfloatN_t_type(bit_size);
break;
default:
vtn_fail("Unsupported OpTypeFloat encoding: %d", encoding);
}
break;
}