spirv: Implement Conversions to/from bfloat16

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Rohan Garg <rohan.garg@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-02-28 01:03:50 -08:00 committed by Marge Bot
parent 90e1b12890
commit 2807097690

View file

@ -668,8 +668,35 @@ static nir_def *
vtn_handle_convert(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *dest_val,
const struct glsl_type *glsl_dest_type,
const struct glsl_type *glsl_src_type,
nir_def *src)
{
/* From SPV_KHR_bfloat16 extension:
*
* Conversions to or from floating-point values with the `BFloat16KHR`
* encoding first convert the source value to IEEE754 binary32, and then
* from IEEE754 binary32 to the target format.
*
* For now we are limiting exposure of bfloat16 in NIR, so apply the
* extra conversions directly here.
*/
if (glsl_type_is_bfloat_16(glsl_src_type)) {
nir_def *src_as_float = nir_bf2f(&b->nb, src);
if (glsl_type_is_float(glsl_dest_type))
return src_as_float;
return vtn_handle_convert(b, opcode, dest_val, glsl_dest_type,
glsl_float_type(), src_as_float);
} else if (glsl_type_is_bfloat_16(glsl_dest_type)) {
nir_def *src_as_float;
if (glsl_type_is_float(glsl_src_type))
src_as_float = src;
else
src_as_float = vtn_handle_convert(b, opcode, dest_val, glsl_float_type(),
glsl_src_type, src);
return nir_f2bf(&b->nb, src_as_float);
}
/* Use bit_size from NIR source instead of from the original src type,
* to account for mediump_16bit. See vtn_handle_alu() for details.
*/
@ -960,7 +987,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
case SpvOpFConvert:
case SpvOpSatConvertSToU:
case SpvOpSatConvertUToS:
dest->def = vtn_handle_convert(b, opcode, dest_val, dest_type, src[0]);
dest->def = vtn_handle_convert(b, opcode, dest_val, dest_type,
vtn_src[0]->type, src[0]);
break;
case SpvOpBitFieldInsert: