spirv: Fix spec constant to handle Select for non-native floats

There was an assumption that if the instruction had a non-native float
as a source, the first source would have that type.  This doesn't
hold for Select, and the code failed in two ways:

- The boolean source of Select was being converted to the non-native
  float type.

- The loop that resolves the bit-size for unsized operands would
  trip at `assert(i == 0)` because Select has more than one source.

Re-organize the code to track the types of the sources independently,
and fix both issues above.

Fixes: 90e1b12890 ("spirv: Add bfloat16 support to SpecConstantOp")
Fixes: 51d3c4c889 ("spirv: support float8 spec constant op")
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
(cherry picked from commit 6affcb43a7)

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40359>
This commit is contained in:
Caio Oliveira 2026-02-28 10:08:42 -08:00 committed by Eric Engestrom
parent 4588b025c8
commit b2a34da82f
3 changed files with 48 additions and 35 deletions

View file

@ -3014,7 +3014,7 @@
"description": "spirv: Fix spec constant to handle Select for non-native floats",
"nominated": true,
"nomination_type": 2,
"resolution": 0,
"resolution": 1,
"main_sha": null,
"because_sha": "90e1b128903cabfe4fcfb5ae52cf46d5ddbf1189",
"notes": null

View file

@ -676,6 +676,14 @@ glsl_type_is_e5m2(const glsl_type *t)
return t->base_type == GLSL_TYPE_FLOAT_E5M2;
}
/* Returns true when the base type of t is one of the three float formats
 * checked below: bfloat16, E4M3FN or E5M2.  Any other base type yields
 * false.
 */
static inline bool
glsl_type_is_nonnative_float(const glsl_type *t)
{
   switch (t->base_type) {
   case GLSL_TYPE_BFLOAT16:
   case GLSL_TYPE_FLOAT_E4M3FN:
   case GLSL_TYPE_FLOAT_E5M2:
      return true;
   default:
      return false;
   }
}
static inline bool
glsl_type_is_int_16_32_64(const glsl_type *t)
{

View file

@ -2865,61 +2865,66 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
default: {
bool swap;
const glsl_type *org_dst_type = val->type->type;
const glsl_type *org_src_type = org_dst_type;
const glsl_type *dst_type = val->type->type;
const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT);
unsigned num_components = glsl_get_vector_elements(val->type->type);
vtn_assert(count <= 7);
const unsigned src_count = count - 4;
struct vtn_value *src_val[3] = {0};
const glsl_type *src_type[3] = {0};
for (unsigned i = 0; i < src_count; i++) {
src_val[i] = vtn_value(b, w[4 + i], vtn_value_type_constant);
src_type[i] = src_val[i]->type->type;
}
unsigned conv_src_bit_size;
switch (opcode) {
case SpvOpConvertFToU:
case SpvOpConvertFToS:
case SpvOpConvertSToF:
case SpvOpConvertUToF:
case SpvOpSConvert:
case SpvOpFConvert:
case SpvOpUConvert:
/* We have a different source type in a conversion. */
org_src_type = vtn_get_value_type(b, w[4])->type;
conv_src_bit_size =
glsl_type_is_nonnative_float(src_type[0]) ? 32 : glsl_get_bit_size(src_type[0]);
break;
default:
/* When picking ALU ops, bit-size is only used for Convert
* operations.
*/
conv_src_bit_size = 0;
break;
};
const glsl_type *dst_type = org_dst_type;
if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type))
dst_type = glsl_float_type();
const glsl_type *src_type = org_src_type;
if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type))
src_type = glsl_float_type();
const unsigned dst_bit_size =
glsl_type_is_nonnative_float(dst_type) ? 32 : glsl_get_bit_size(dst_type);
bool exact;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
glsl_get_bit_size(src_type),
glsl_get_bit_size(dst_type));
conv_src_bit_size, dst_bit_size);
/* No SPIR-V opcodes handled through this path should set exact.
* Since it is ignored, assert on it.
*/
assert(!exact);
unsigned bit_size = glsl_get_bit_size(dst_type);
unsigned resolved_bit_size = dst_bit_size;
nir_const_value src[3][NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < count - 4; i++) {
struct vtn_value *src_val =
vtn_value(b, w[4 + i], vtn_value_type_constant);
for (unsigned i = 0; i < src_count; i++) {
/* If this is an unsized source, pull the bit size from the
* source; otherwise, we'll use the bit size from the destination.
*/
if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) {
if (org_src_type != src_type) {
/* Small float conversion. */
assert(i == 0);
bit_size = glsl_get_bit_size(src_type);
} else {
bit_size = glsl_get_bit_size(src_val->type->type);
}
resolved_bit_size = glsl_type_is_nonnative_float(src_type[i]) ?
32 : glsl_get_bit_size(src_type[i]);
}
unsigned src_comps = nir_op_infos[op].input_sizes[i] ?
@ -2928,12 +2933,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
unsigned j = swap ? 1 - i : i;
for (unsigned c = 0; c < src_comps; c++) {
src[j][c] = src_val->constant->values[c];
if (glsl_type_is_bfloat_16(org_src_type))
src[j][c] = src_val[i]->constant->values[c];
if (glsl_type_is_bfloat_16(src_type[i]))
src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16);
else if (glsl_type_is_e4m3fn(org_src_type))
else if (glsl_type_is_e4m3fn(src_type[i]))
src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8);
else if (glsl_type_is_e5m2(org_src_type))
else if (glsl_type_is_e5m2(src_type[i]))
src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8);
}
@ -2945,7 +2950,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
/* Shift amount in NIR ops must be 32-bit. */
vtn_assert(!swap);
const unsigned shift_idx = 1;
const unsigned shift_bit_size = glsl_get_bit_size(src_val->type->type);
const unsigned shift_bit_size = glsl_get_bit_size(src_type[i]);
if (i != shift_idx || shift_bit_size == 32)
break;
for (unsigned c = 0; c < src_comps; c++) {
@ -2964,19 +2969,19 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
src[0], src[1], src[2],
};
nir_eval_const_opcode(op, val->constant->values, NULL,
num_components, bit_size, srcs,
num_components, resolved_bit_size, srcs,
b->shader->info.float_controls_execution_mode);
for (int i = 0; i < num_components; i++) {
uint16_t conv;
if (glsl_type_is_bfloat_16(org_dst_type)) {
if (glsl_type_is_bfloat_16(dst_type)) {
conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
} else if (glsl_type_is_e4m3fn(org_dst_type)) {
} else if (glsl_type_is_e4m3fn(dst_type)) {
if (saturate)
conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32);
else
conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32);
} else if (glsl_type_is_e5m2(org_dst_type)) {
} else if (glsl_type_is_e5m2(dst_type)) {
if (saturate)
conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32);
else
@ -2985,7 +2990,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
continue;
}
val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type));
val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(dst_type));
}
break;