spirv: Add bfloat16 support to SpecConstantOp

Handle bfloat16 by converting sources to float, performing the
operation, and converting result back to bfloat16 if needed.  This is
done because not all ALU ops have a `bf` version in NIR.

Reviewed-by: Rohan Garg <rohan.garg@intel.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-03-19 09:12:38 -07:00 committed by Marge Bot
parent dc8074683d
commit 90e1b12890

View file

@ -40,6 +40,7 @@
#include "util/u_debug.h"
#include "util/u_printf.h"
#include "util/mesa-blake3.h"
#include "util/bfloat.h"
#include <stdio.h>
@ -2697,6 +2698,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
const glsl_type *dst_type = val->type->type;
const glsl_type *src_type = dst_type;
const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type);
bool bfloat_src = bfloat_dst;
if (bfloat_dst)
dst_type = glsl_float_type();
unsigned num_components = glsl_get_vector_elements(val->type->type);
vtn_assert(count <= 7);
@ -2704,10 +2711,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
switch (opcode) {
case SpvOpSConvert:
case SpvOpFConvert:
case SpvOpUConvert:
case SpvOpUConvert: {
/* We have a different source type in a conversion. */
src_type = vtn_get_value_type(b, w[4])->type;
bfloat_src = glsl_type_is_bfloat_16(src_type);
if (bfloat_src)
src_type = glsl_float_type();
break;
}
default:
break;
};
@ -2721,7 +2732,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
*/
assert(!exact);
unsigned bit_size = glsl_get_bit_size(val->type->type);
unsigned bit_size = glsl_get_bit_size(src_type);
nir_const_value src[3][NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < count - 4; i++) {
@ -2739,8 +2750,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
num_components;
unsigned j = swap ? 1 - i : i;
for (unsigned c = 0; c < src_comps; c++)
for (unsigned c = 0; c < src_comps; c++) {
src[j][c] = src_val->constant->values[c];
if (bfloat_src)
src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16);
}
}
/* fix up fixed size sources */
@ -2769,6 +2783,16 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
nir_eval_const_opcode(op, val->constant->values,
num_components, bit_size, srcs,
b->shader->info.float_controls_execution_mode);
if (bfloat_dst) {
for (int i = 0; i < num_components; i++) {
/* Ensure the pad bits are zeroed by fully assigning the value. */
const uint16_t b =
_mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
val->constant->values[i] = (nir_const_value){ .u16 = b };
}
}
break;
} /* default */
}