diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index f5fbfcee841..198bebc2744 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -41,6 +41,7 @@ #include "util/u_printf.h" #include "util/mesa-blake3.h" #include "util/bfloat.h" +#include "util/float8.h" #include @@ -2736,15 +2737,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, default: { bool swap; - const glsl_type *dst_type = val->type->type; - const glsl_type *src_type = dst_type; - - const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type); - bool bfloat_src = bfloat_dst; - - if (bfloat_dst) - dst_type = glsl_float_type(); + const glsl_type *org_dst_type = val->type->type; + const glsl_type *org_src_type = org_dst_type; + const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT); unsigned num_components = glsl_get_vector_elements(val->type->type); vtn_assert(count <= 7); @@ -2752,18 +2748,22 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpSConvert: case SpvOpFConvert: - case SpvOpUConvert: { + case SpvOpUConvert: /* We have a different source type in a conversion. */ - src_type = vtn_get_value_type(b, w[4])->type; - bfloat_src = glsl_type_is_bfloat_16(src_type); - if (bfloat_src) - src_type = glsl_float_type(); + org_src_type = vtn_get_value_type(b, w[4])->type; break; - } default: break; }; + const glsl_type *dst_type = org_dst_type; + if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type)) + dst_type = glsl_float_type(); + + const glsl_type *src_type = org_src_type; + if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type)) + src_type = glsl_float_type(); + bool exact; nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact, src_type, dst_type); @@ -2773,7 +2773,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, */ assert(!exact); - unsigned bit_size = glsl_get_bit_size(src_type); + unsigned bit_size = glsl_get_bit_size(dst_type); nir_const_value src[3][NIR_MAX_VEC_COMPONENTS]; for (unsigned i = 0; i < count - 4; i++) { @@ -2783,8 +2783,15 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, /* If this is an unsized source, pull the bit size from the * source; otherwise, we'll use the bit size from the destination. */ - if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) - bit_size = glsl_get_bit_size(src_val->type->type); + if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) { + if (org_src_type != src_type) { + /* Small float conversion. */ + assert(i == 0); + bit_size = glsl_get_bit_size(src_type); + } else { + bit_size = glsl_get_bit_size(src_val->type->type); + } + } unsigned src_comps = nir_op_infos[op].input_sizes[i] ? nir_op_infos[op].input_sizes[i] : @@ -2793,8 +2800,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, unsigned j = swap ? 1 - i : i; for (unsigned c = 0; c < src_comps; c++) { src[j][c] = src_val->constant->values[c]; - if (bfloat_src) + if (glsl_type_is_bfloat_16(org_src_type)) src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16); + else if (glsl_type_is_e4m3fn(org_src_type)) + src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8); + else if (glsl_type_is_e5m2(org_src_type)) + src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8); } } @@ -2825,12 +2836,25 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, num_components, bit_size, srcs, b->shader->info.float_controls_execution_mode); - if (bfloat_dst) { - for (int i = 0; i < num_components; i++) { - const uint16_t b = - _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32); - val->constant->values[i] = nir_const_value_for_raw_uint(b, 16); + for (int i = 0; i < num_components; i++) { + uint16_t conv; + if (glsl_type_is_bfloat_16(org_dst_type)) { + conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32); + } else if (glsl_type_is_e4m3fn(org_dst_type)) { + if (saturate) + conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32); + else + conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32); + } else if (glsl_type_is_e5m2(org_dst_type)) { + if (saturate) + conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32); + else + conv = _mesa_float_to_e5m2(val->constant->values[i].f32); + } else { + continue; } + + val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type)); } break;