mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-12 17:40:32 +01:00
spirv: support float8 spec constant op
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35434>
This commit is contained in:
parent
55b2f4958f
commit
51d3c4c889
1 changed files with 47 additions and 23 deletions
|
|
@ -41,6 +41,7 @@
|
|||
#include "util/u_printf.h"
|
||||
#include "util/mesa-blake3.h"
|
||||
#include "util/bfloat.h"
|
||||
#include "util/float8.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
|
|
@ -2736,15 +2737,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
default: {
|
||||
bool swap;
|
||||
|
||||
const glsl_type *dst_type = val->type->type;
|
||||
const glsl_type *src_type = dst_type;
|
||||
|
||||
const bool bfloat_dst = glsl_type_is_bfloat_16(dst_type);
|
||||
bool bfloat_src = bfloat_dst;
|
||||
|
||||
if (bfloat_dst)
|
||||
dst_type = glsl_float_type();
|
||||
const glsl_type *org_dst_type = val->type->type;
|
||||
const glsl_type *org_src_type = org_dst_type;
|
||||
|
||||
const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT);
|
||||
unsigned num_components = glsl_get_vector_elements(val->type->type);
|
||||
|
||||
vtn_assert(count <= 7);
|
||||
|
|
@ -2752,18 +2748,22 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
switch (opcode) {
|
||||
case SpvOpSConvert:
|
||||
case SpvOpFConvert:
|
||||
case SpvOpUConvert: {
|
||||
case SpvOpUConvert:
|
||||
/* We have a different source type in a conversion. */
|
||||
src_type = vtn_get_value_type(b, w[4])->type;
|
||||
bfloat_src = glsl_type_is_bfloat_16(src_type);
|
||||
if (bfloat_src)
|
||||
src_type = glsl_float_type();
|
||||
org_src_type = vtn_get_value_type(b, w[4])->type;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
};
|
||||
|
||||
const glsl_type *dst_type = org_dst_type;
|
||||
if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type))
|
||||
dst_type = glsl_float_type();
|
||||
|
||||
const glsl_type *src_type = org_src_type;
|
||||
if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type))
|
||||
src_type = glsl_float_type();
|
||||
|
||||
bool exact;
|
||||
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
|
||||
src_type, dst_type);
|
||||
|
|
@ -2773,7 +2773,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
*/
|
||||
assert(!exact);
|
||||
|
||||
unsigned bit_size = glsl_get_bit_size(src_type);
|
||||
unsigned bit_size = glsl_get_bit_size(dst_type);
|
||||
nir_const_value src[3][NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
for (unsigned i = 0; i < count - 4; i++) {
|
||||
|
|
@ -2783,8 +2783,15 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
/* If this is an unsized source, pull the bit size from the
|
||||
* source; otherwise, we'll use the bit size from the destination.
|
||||
*/
|
||||
if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]))
|
||||
bit_size = glsl_get_bit_size(src_val->type->type);
|
||||
if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) {
|
||||
if (org_src_type != src_type) {
|
||||
/* Small float conversion. */
|
||||
assert(i == 0);
|
||||
bit_size = glsl_get_bit_size(src_type);
|
||||
} else {
|
||||
bit_size = glsl_get_bit_size(src_val->type->type);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned src_comps = nir_op_infos[op].input_sizes[i] ?
|
||||
nir_op_infos[op].input_sizes[i] :
|
||||
|
|
@ -2793,8 +2800,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
unsigned j = swap ? 1 - i : i;
|
||||
for (unsigned c = 0; c < src_comps; c++) {
|
||||
src[j][c] = src_val->constant->values[c];
|
||||
if (bfloat_src)
|
||||
if (glsl_type_is_bfloat_16(org_src_type))
|
||||
src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16);
|
||||
else if (glsl_type_is_e4m3fn(org_src_type))
|
||||
src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8);
|
||||
else if (glsl_type_is_e5m2(org_src_type))
|
||||
src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2825,12 +2836,25 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
num_components, bit_size, srcs,
|
||||
b->shader->info.float_controls_execution_mode);
|
||||
|
||||
if (bfloat_dst) {
|
||||
for (int i = 0; i < num_components; i++) {
|
||||
const uint16_t b =
|
||||
_mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
|
||||
val->constant->values[i] = nir_const_value_for_raw_uint(b, 16);
|
||||
for (int i = 0; i < num_components; i++) {
|
||||
uint16_t conv;
|
||||
if (glsl_type_is_bfloat_16(org_dst_type)) {
|
||||
conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
|
||||
} else if (glsl_type_is_e4m3fn(org_dst_type)) {
|
||||
if (saturate)
|
||||
conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32);
|
||||
else
|
||||
conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32);
|
||||
} else if (glsl_type_is_e5m2(org_dst_type)) {
|
||||
if (saturate)
|
||||
conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32);
|
||||
else
|
||||
conv = _mesa_float_to_e5m2(val->constant->values[i].f32);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type));
|
||||
}
|
||||
|
||||
break;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue