mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-20 22:30:12 +01:00
spirv: Add bfloat16 support to SpecConstantOp
Handle bfloat16 by converting sources to float, performing the operation, and converting result back to bfloat16 if needed. This is done because not all ALU ops have a `bf` version in NIR. Reviewed-by: Rohan Garg <rohan.garg@intel.com> Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
parent
dc8074683d
commit
90e1b12890
1 changed files with 27 additions and 3 deletions
|
|
@ -40,6 +40,7 @@
|
|||
#include "util/u_debug.h"
|
||||
#include "util/u_printf.h"
|
||||
#include "util/mesa-blake3.h"
|
||||
#include "util/bfloat.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
|
|
@ -2697,6 +2698,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
const glsl_type *dst_type = val->type->type;
|
||||
const glsl_type *src_type = dst_type;
|
||||
|
||||
const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type);
|
||||
bool bfloat_src = bfloat_dst;
|
||||
|
||||
if (bfloat_dst)
|
||||
dst_type = glsl_float_type();
|
||||
|
||||
unsigned num_components = glsl_get_vector_elements(val->type->type);
|
||||
|
||||
vtn_assert(count <= 7);
|
||||
|
|
@ -2704,10 +2711,14 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
switch (opcode) {
|
||||
case SpvOpSConvert:
|
||||
case SpvOpFConvert:
|
||||
case SpvOpUConvert:
|
||||
case SpvOpUConvert: {
|
||||
/* We have a different source type in a conversion. */
|
||||
src_type = vtn_get_value_type(b, w[4])->type;
|
||||
bfloat_src = glsl_type_is_bfloat_16(src_type);
|
||||
if (bfloat_src)
|
||||
src_type = glsl_float_type();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
};
|
||||
|
|
@ -2721,7 +2732,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
*/
|
||||
assert(!exact);
|
||||
|
||||
unsigned bit_size = glsl_get_bit_size(val->type->type);
|
||||
unsigned bit_size = glsl_get_bit_size(src_type);
|
||||
nir_const_value src[3][NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
for (unsigned i = 0; i < count - 4; i++) {
|
||||
|
|
@ -2739,8 +2750,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
num_components;
|
||||
|
||||
unsigned j = swap ? 1 - i : i;
|
||||
for (unsigned c = 0; c < src_comps; c++)
|
||||
for (unsigned c = 0; c < src_comps; c++) {
|
||||
src[j][c] = src_val->constant->values[c];
|
||||
if (bfloat_src)
|
||||
src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16);
|
||||
}
|
||||
}
|
||||
|
||||
/* fix up fixed size sources */
|
||||
|
|
@ -2769,6 +2783,16 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
nir_eval_const_opcode(op, val->constant->values,
|
||||
num_components, bit_size, srcs,
|
||||
b->shader->info.float_controls_execution_mode);
|
||||
|
||||
if (bfloat_dst) {
|
||||
for (int i = 0; i < num_components; i++) {
|
||||
/* Ensure the pad bits are zeroed by fully assigning the value. */
|
||||
const uint16_t b =
|
||||
_mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
|
||||
val->constant->values[i] = (nir_const_value){ .u16 = b };
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
} /* default */
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue