From b2a34da82f7af6be3c92d028c5853d78c5545f43 Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Sat, 28 Feb 2026 10:08:42 -0800 Subject: [PATCH] spirv: Fix spec constant to handle Select for non-native floats There was an assumption that if the instruction had non-native float as a source, the first source would have such type. This doesn't hold for Select, and the code failed in two ways - The boolean source of Select was being converted to the non-native float type. - The loop that resolves the bit-size for unsized operands would trip at `assert(i == 0)` because Select has more than one source. Re-organize the code to track the types of the sources independently, and fix both issues above. Fixes: 90e1b128903 ("spirv: Add bfloat16 support to SpecConstantOp") Fixes: 51d3c4c8896 ("spirv: support float8 spec constant op") Reviewed-by: Georg Lehmann (cherry picked from commit 6affcb43a7c765b19726ea82ec451a50714ba10e) Part-of: --- .pick_status.json | 2 +- src/compiler/glsl_types.h | 8 ++++ src/compiler/spirv/spirv_to_nir.c | 73 +++++++++++++++++-------------- 3 files changed, 48 insertions(+), 35 deletions(-) diff --git a/.pick_status.json b/.pick_status.json index fb69677b91d..1e5a42399fb 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -3014,7 +3014,7 @@ "description": "spirv: Fix spec constant to handle Select for non-native floats", "nominated": true, "nomination_type": 2, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "90e1b128903cabfe4fcfb5ae52cf46d5ddbf1189", "notes": null diff --git a/src/compiler/glsl_types.h b/src/compiler/glsl_types.h index 2d2f643f94c..3ec464464e0 100644 --- a/src/compiler/glsl_types.h +++ b/src/compiler/glsl_types.h @@ -676,6 +676,14 @@ glsl_type_is_e5m2(const glsl_type *t) return t->base_type == GLSL_TYPE_FLOAT_E5M2; } +static inline bool +glsl_type_is_nonnative_float(const glsl_type *t) +{ + return t->base_type == GLSL_TYPE_BFLOAT16 || + t->base_type == GLSL_TYPE_FLOAT_E4M3FN || + t->base_type == GLSL_TYPE_FLOAT_E5M2; +} + static inline bool glsl_type_is_int_16_32_64(const glsl_type *t) { diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 1b863640fea..f335a173258 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2865,61 +2865,66 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, default: { bool swap; - const glsl_type *org_dst_type = val->type->type; - const glsl_type *org_src_type = org_dst_type; + const glsl_type *dst_type = val->type->type; const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT); unsigned num_components = glsl_get_vector_elements(val->type->type); vtn_assert(count <= 7); + const unsigned src_count = count - 4; + struct vtn_value *src_val[3] = {0}; + const glsl_type *src_type[3] = {0}; + + for (unsigned i = 0; i < src_count; i++) { + src_val[i] = vtn_value(b, w[4 + i], vtn_value_type_constant); + src_type[i] = src_val[i]->type->type; + } + + unsigned conv_src_bit_size; switch (opcode) { + case SpvOpConvertFToU: + case SpvOpConvertFToS: + case SpvOpConvertSToF: + case SpvOpConvertUToF: case SpvOpSConvert: case SpvOpFConvert: case SpvOpUConvert: /* We have a different source type in a conversion. */ - org_src_type = vtn_get_value_type(b, w[4])->type; + conv_src_bit_size = + glsl_type_is_nonnative_float(src_type[0]) ? 32 : glsl_get_bit_size(src_type[0]); break; default: + /* When picking ALU ops, bit-size is only used for Convert + * operations. + */ + conv_src_bit_size = 0; break; }; - const glsl_type *dst_type = org_dst_type; - if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type)) - dst_type = glsl_float_type(); - - const glsl_type *src_type = org_src_type; - if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type)) - src_type = glsl_float_type(); + const unsigned dst_bit_size = + glsl_type_is_nonnative_float(dst_type) ? 32 : glsl_get_bit_size(dst_type); bool exact; nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact, - glsl_get_bit_size(src_type), - glsl_get_bit_size(dst_type)); + conv_src_bit_size, dst_bit_size); /* No SPIR-V opcodes handled through this path should set exact. * Since it is ignored, assert on it. */ assert(!exact); - unsigned bit_size = glsl_get_bit_size(dst_type); + unsigned resolved_bit_size = dst_bit_size; + nir_const_value src[3][NIR_MAX_VEC_COMPONENTS]; - for (unsigned i = 0; i < count - 4; i++) { - struct vtn_value *src_val = - vtn_value(b, w[4 + i], vtn_value_type_constant); - + for (unsigned i = 0; i < src_count; i++) { /* If this is an unsized source, pull the bit size from the * source; otherwise, we'll use the bit size from the destination. */ if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) { - if (org_src_type != src_type) { - /* Small float conversion. */ - assert(i == 0); - bit_size = glsl_get_bit_size(src_type); - } else { - bit_size = glsl_get_bit_size(src_val->type->type); - } + resolved_bit_size = glsl_type_is_nonnative_float(src_type[i]) ? + 32 : glsl_get_bit_size(src_type[i]); } unsigned src_comps = nir_op_infos[op].input_sizes[i] ? @@ -2928,12 +2933,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, unsigned j = swap ? 1 - i : i; for (unsigned c = 0; c < src_comps; c++) { - src[j][c] = src_val->constant->values[c]; - if (glsl_type_is_bfloat_16(org_src_type)) + src[j][c] = src_val[i]->constant->values[c]; + if (glsl_type_is_bfloat_16(src_type[i])) src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16); - else if (glsl_type_is_e4m3fn(org_src_type)) + else if (glsl_type_is_e4m3fn(src_type[i])) src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8); - else if (glsl_type_is_e5m2(org_src_type)) + else if (glsl_type_is_e5m2(src_type[i])) src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8); } @@ -2945,7 +2950,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, /* Shift amount in NIR ops must be 32-bit. */ vtn_assert(!swap); const unsigned shift_idx = 1; - const unsigned shift_bit_size = glsl_get_bit_size(src_val->type->type); + const unsigned shift_bit_size = glsl_get_bit_size(src_type[i]); if (i != shift_idx || shift_bit_size == 32) break; for (unsigned c = 0; c < src_comps; c++) { @@ -2964,19 +2969,19 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, src[0], src[1], src[2], }; nir_eval_const_opcode(op, val->constant->values, NULL, - num_components, bit_size, srcs, + num_components, resolved_bit_size, srcs, b->shader->info.float_controls_execution_mode); for (int i = 0; i < num_components; i++) { uint16_t conv; - if (glsl_type_is_bfloat_16(org_dst_type)) { + if (glsl_type_is_bfloat_16(dst_type)) { conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32); - } else if (glsl_type_is_e4m3fn(org_dst_type)) { + } else if (glsl_type_is_e4m3fn(dst_type)) { if (saturate) conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32); else conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32); - } else if (glsl_type_is_e5m2(org_dst_type)) { + } else if (glsl_type_is_e5m2(dst_type)) { if (saturate) conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32); else @@ -2985,7 +2990,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, continue; } - val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type)); + val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(dst_type)); } break;