From 6affcb43a7c765b19726ea82ec451a50714ba10e Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Sat, 28 Feb 2026 10:08:42 -0800 Subject: [PATCH] spirv: Fix spec constant to handle Select for non-native floats There was an assumption that if the instruction had non-native float as a source, the first source would have such type. This doesn't hold for Select, and the code failed in two ways - The boolean source of Select was being converted to the non-native float type. - The loop that resolves the bit-size for unsized operands would trip at `assert(i == 0)` because Select has more than one source. Re-organize the code to track the types of the sources independently, and fix both issues above. Fixes: 90e1b128903 ("spirv: Add bfloat16 support to SpecConstantOp") Fixes: 51d3c4c8896 ("spirv: support float8 spec constant op") Reviewed-by: Georg Lehmann Part-of: --- src/compiler/glsl_types.h | 8 ++++ src/compiler/spirv/spirv_to_nir.c | 73 +++++++++++++++++-------------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/compiler/glsl_types.h b/src/compiler/glsl_types.h index 2d2f643f94c..3ec464464e0 100644 --- a/src/compiler/glsl_types.h +++ b/src/compiler/glsl_types.h @@ -676,6 +676,14 @@ glsl_type_is_e5m2(const glsl_type *t) return t->base_type == GLSL_TYPE_FLOAT_E5M2; } +static inline bool +glsl_type_is_nonnative_float(const glsl_type *t) +{ + return t->base_type == GLSL_TYPE_BFLOAT16 || + t->base_type == GLSL_TYPE_FLOAT_E4M3FN || + t->base_type == GLSL_TYPE_FLOAT_E5M2; +} + static inline bool glsl_type_is_int_16_32_64(const glsl_type *t) { diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index f704da75c18..2806ea1fd82 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2873,61 +2873,66 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, default: { bool swap; - const glsl_type *org_dst_type = val->type->type; - const glsl_type *org_src_type = org_dst_type; + const glsl_type *dst_type = val->type->type; const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT); unsigned num_components = glsl_get_vector_elements(val->type->type); vtn_assert(count <= 7); + const unsigned src_count = count - 4; + struct vtn_value *src_val[3] = {0}; + const glsl_type *src_type[3] = {0}; + + for (unsigned i = 0; i < src_count; i++) { + src_val[i] = vtn_value(b, w[4 + i], vtn_value_type_constant); + src_type[i] = src_val[i]->type->type; + } + + unsigned conv_src_bit_size; switch (opcode) { + case SpvOpConvertFToU: + case SpvOpConvertFToS: + case SpvOpConvertSToF: + case SpvOpConvertUToF: case SpvOpSConvert: case SpvOpFConvert: case SpvOpUConvert: /* We have a different source type in a conversion. */ - org_src_type = vtn_get_value_type(b, w[4])->type; + conv_src_bit_size = + glsl_type_is_nonnative_float(src_type[0]) ? 32 : glsl_get_bit_size(src_type[0]); break; default: + /* When picking ALU ops, bit-size is only used for Convert + * operations. + */ + conv_src_bit_size = 0; break; }; - const glsl_type *dst_type = org_dst_type; - if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type)) - dst_type = glsl_float_type(); - - const glsl_type *src_type = org_src_type; - if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type)) - src_type = glsl_float_type(); + const unsigned dst_bit_size = + glsl_type_is_nonnative_float(dst_type) ? 32 : glsl_get_bit_size(dst_type); unsigned extra_fp_math_ctrl; nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &extra_fp_math_ctrl, - glsl_get_bit_size(src_type), - glsl_get_bit_size(dst_type)); + conv_src_bit_size, dst_bit_size); /* No SPIR-V opcodes handled through this path should set fast math. * Since it is ignored, assert on it. */ assert(!extra_fp_math_ctrl); - unsigned bit_size = glsl_get_bit_size(dst_type); + unsigned resolved_bit_size = dst_bit_size; + nir_const_value src[3][NIR_MAX_VEC_COMPONENTS]; - for (unsigned i = 0; i < count - 4; i++) { - struct vtn_value *src_val = - vtn_value(b, w[4 + i], vtn_value_type_constant); - + for (unsigned i = 0; i < src_count; i++) { /* If this is an unsized source, pull the bit size from the * source; otherwise, we'll use the bit size from the destination. */ if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) { - if (org_src_type != src_type) { - /* Small float conversion. */ - assert(i == 0); - bit_size = glsl_get_bit_size(src_type); - } else { - bit_size = glsl_get_bit_size(src_val->type->type); - } + resolved_bit_size = glsl_type_is_nonnative_float(src_type[i]) ? + 32 : glsl_get_bit_size(src_type[i]); } unsigned src_comps = nir_op_infos[op].input_sizes[i] ? @@ -2936,12 +2941,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, unsigned j = swap ? 1 - i : i; for (unsigned c = 0; c < src_comps; c++) { - src[j][c] = src_val->constant->values[c]; - if (glsl_type_is_bfloat_16(org_src_type)) + src[j][c] = src_val[i]->constant->values[c]; + if (glsl_type_is_bfloat_16(src_type[i])) src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16); - else if (glsl_type_is_e4m3fn(org_src_type)) + else if (glsl_type_is_e4m3fn(src_type[i])) src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8); - else if (glsl_type_is_e5m2(org_src_type)) + else if (glsl_type_is_e5m2(src_type[i])) src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8); } @@ -2953,7 +2958,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, /* Shift amount in NIR ops must be 32-bit. */ vtn_assert(!swap); const unsigned shift_idx = 1; - const unsigned shift_bit_size = glsl_get_bit_size(src_val->type->type); + const unsigned shift_bit_size = glsl_get_bit_size(src_type[i]); if (i != shift_idx || shift_bit_size == 32) break; for (unsigned c = 0; c < src_comps; c++) { @@ -2972,19 +2977,19 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, src[0], src[1], src[2], }; nir_eval_const_opcode(op, val->constant->values, NULL, - num_components, bit_size, srcs, + num_components, resolved_bit_size, srcs, b->shader->info.float_controls_execution_mode); for (int i = 0; i < num_components; i++) { uint16_t conv; - if (glsl_type_is_bfloat_16(org_dst_type)) { + if (glsl_type_is_bfloat_16(dst_type)) { conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32); - } else if (glsl_type_is_e4m3fn(org_dst_type)) { + } else if (glsl_type_is_e4m3fn(dst_type)) { if (saturate) conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32); else conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32); - } else if (glsl_type_is_e5m2(org_dst_type)) { + } else if (glsl_type_is_e5m2(dst_type)) { if (saturate) conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32); else @@ -2993,7 +2998,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, continue; } - val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type)); + val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(dst_type)); } break;