spirv: support float8 spec constant op

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35434>
This commit is contained in:
Georg Lehmann 2025-06-08 14:04:56 +02:00 committed by Marge Bot
parent 55b2f4958f
commit 51d3c4c889

View file

@ -41,6 +41,7 @@
#include "util/u_printf.h"
#include "util/mesa-blake3.h"
#include "util/bfloat.h"
#include "util/float8.h"
#include <stdio.h>
@ -2736,15 +2737,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
default: {
bool swap;
const glsl_type *dst_type = val->type->type;
const glsl_type *src_type = dst_type;
const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type);
bool bfloat_src = bfloat_dst;
if (bfloat_dst)
dst_type = glsl_float_type();
const glsl_type *org_dst_type = val->type->type;
const glsl_type *org_src_type = org_dst_type;
const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT);
unsigned num_components = glsl_get_vector_elements(val->type->type);
vtn_assert(count <= 7);
@ -2752,18 +2748,22 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
switch (opcode) {
case SpvOpSConvert:
case SpvOpFConvert:
case SpvOpUConvert: {
case SpvOpUConvert:
/* We have a different source type in a conversion. */
src_type = vtn_get_value_type(b, w[4])->type;
bfloat_src = glsl_type_is_bfloat_16(src_type);
if (bfloat_src)
src_type = glsl_float_type();
org_src_type = vtn_get_value_type(b, w[4])->type;
break;
}
default:
break;
};
const glsl_type *dst_type = org_dst_type;
if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type))
dst_type = glsl_float_type();
const glsl_type *src_type = org_src_type;
if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type))
src_type = glsl_float_type();
bool exact;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
src_type, dst_type);
@ -2773,7 +2773,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
*/
assert(!exact);
unsigned bit_size = glsl_get_bit_size(src_type);
unsigned bit_size = glsl_get_bit_size(dst_type);
nir_const_value src[3][NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < count - 4; i++) {
@ -2783,8 +2783,15 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
/* If this is an unsized source, pull the bit size from the
* source; otherwise, we'll use the bit size from the destination.
*/
if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]))
bit_size = glsl_get_bit_size(src_val->type->type);
if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) {
if (org_src_type != src_type) {
/* Small float conversion. */
assert(i == 0);
bit_size = glsl_get_bit_size(src_type);
} else {
bit_size = glsl_get_bit_size(src_val->type->type);
}
}
unsigned src_comps = nir_op_infos[op].input_sizes[i] ?
nir_op_infos[op].input_sizes[i] :
@ -2793,8 +2800,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
unsigned j = swap ? 1 - i : i;
for (unsigned c = 0; c < src_comps; c++) {
src[j][c] = src_val->constant->values[c];
if (bfloat_src)
if (glsl_type_is_bfloat_16(org_src_type))
src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16);
else if (glsl_type_is_e4m3fn(org_src_type))
src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8);
else if (glsl_type_is_e5m2(org_src_type))
src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8);
}
}
@ -2825,12 +2836,25 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
num_components, bit_size, srcs,
b->shader->info.float_controls_execution_mode);
if (bfloat_dst) {
for (int i = 0; i < num_components; i++) {
const uint16_t b =
_mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
val->constant->values[i] = nir_const_value_for_raw_uint(b, 16);
for (int i = 0; i < num_components; i++) {
uint16_t conv;
if (glsl_type_is_bfloat_16(org_dst_type)) {
conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
} else if (glsl_type_is_e4m3fn(org_dst_type)) {
if (saturate)
conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32);
else
conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32);
} else if (glsl_type_is_e5m2(org_dst_type)) {
if (saturate)
conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32);
else
conv = _mesa_float_to_e5m2(val->constant->values[i].f32);
} else {
continue;
}
val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type));
}
break;