From 28070976904a087ce304e32deed7bb08619df626 Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Fri, 28 Feb 2025 01:03:50 -0800 Subject: [PATCH] spirv: Implement Conversions to/from bfloat16 Reviewed-by: Ian Romanick Reviewed-by: Rohan Garg Reviewed-by: Georg Lehmann Part-of: --- src/compiler/spirv/vtn_alu.c | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index f0b459a53d0..544e820d1ea 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -668,8 +668,35 @@ static nir_def * vtn_handle_convert(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val, const struct glsl_type *glsl_dest_type, + const struct glsl_type *glsl_src_type, nir_def *src) { + /* From SPV_KHR_bfloat16 extension: + * + * Conversions to or from floating-point values with the `BFloat16KHR` + * encoding first convert the source value to IEEE754 binary32, and then + * from IEEE754 binary32 to the target format. + * + * For now we are limiting exposure of bfloat16 in NIR, so apply the + * extra conversions directly here. + */ + if (glsl_type_is_bfloat_16(glsl_src_type)) { + nir_def *src_as_float = nir_bf2f(&b->nb, src); + if (glsl_type_is_float(glsl_dest_type)) + return src_as_float; + return vtn_handle_convert(b, opcode, dest_val, glsl_dest_type, + glsl_float_type(), src_as_float); + + } else if (glsl_type_is_bfloat_16(glsl_dest_type)) { + nir_def *src_as_float; + if (glsl_type_is_float(glsl_src_type)) + src_as_float = src; + else + src_as_float = vtn_handle_convert(b, opcode, dest_val, glsl_float_type(), + glsl_src_type, src); + return nir_f2bf(&b->nb, src_as_float); + } + /* Use bit_size from NIR source instead of from the original src type, * to account for mediump_16bit. See vtn_handle_alu() for details. */ @@ -960,7 +987,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, case SpvOpFConvert: case SpvOpSatConvertSToU: case SpvOpSatConvertUToS: - dest->def = vtn_handle_convert(b, opcode, dest_val, dest_type, src[0]); + dest->def = vtn_handle_convert(b, opcode, dest_val, dest_type, + vtn_src[0]->type, src[0]); break; case SpvOpBitFieldInsert: