mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-07 00:38:48 +02:00
spirv: Fix spec constant to handle Select for non-native floats
There was an assumption that if the instruction had non-native float as a source, the first source would have such type. This doesn't hold for Select, and the code failed in two ways - The boolean source of Select was being converted to the non-native float type. - The loop that resolves the bit-size for unsized operands would trip at `assert(i == 0)` because Select has more than one source. Re-organize the code to track the types of the sources independently, and fix both issues above. Fixes:90e1b12890("spirv: Add bfloat16 support to SpecConstantOp") Fixes:51d3c4c889("spirv: support float8 spec constant op") Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> (cherry picked from commit6affcb43a7) Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40359>
This commit is contained in:
parent
4588b025c8
commit
b2a34da82f
3 changed files with 48 additions and 35 deletions
|
|
@ -3014,7 +3014,7 @@
|
|||
"description": "spirv: Fix spec constant to handle Select for non-native floats",
|
||||
"nominated": true,
|
||||
"nomination_type": 2,
|
||||
"resolution": 0,
|
||||
"resolution": 1,
|
||||
"main_sha": null,
|
||||
"because_sha": "90e1b128903cabfe4fcfb5ae52cf46d5ddbf1189",
|
||||
"notes": null
|
||||
|
|
|
|||
|
|
@ -676,6 +676,14 @@ glsl_type_is_e5m2(const glsl_type *t)
|
|||
return t->base_type == GLSL_TYPE_FLOAT_E5M2;
|
||||
}
|
||||
|
||||
static inline bool
|
||||
glsl_type_is_nonnative_float(const glsl_type *t)
|
||||
{
|
||||
return t->base_type == GLSL_TYPE_BFLOAT16 ||
|
||||
t->base_type == GLSL_TYPE_FLOAT_E4M3FN ||
|
||||
t->base_type == GLSL_TYPE_FLOAT_E5M2;
|
||||
}
|
||||
|
||||
static inline bool
|
||||
glsl_type_is_int_16_32_64(const glsl_type *t)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2865,61 +2865,66 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
default: {
|
||||
bool swap;
|
||||
|
||||
const glsl_type *org_dst_type = val->type->type;
|
||||
const glsl_type *org_src_type = org_dst_type;
|
||||
const glsl_type *dst_type = val->type->type;
|
||||
|
||||
const bool saturate = vtn_has_decoration(b, val, SpvDecorationSaturatedToLargestFloat8NormalConversionEXT);
|
||||
unsigned num_components = glsl_get_vector_elements(val->type->type);
|
||||
|
||||
vtn_assert(count <= 7);
|
||||
|
||||
const unsigned src_count = count - 4;
|
||||
struct vtn_value *src_val[3] = {0};
|
||||
const glsl_type *src_type[3] = {0};
|
||||
|
||||
for (unsigned i = 0; i < src_count; i++) {
|
||||
src_val[i] = vtn_value(b, w[4 + i], vtn_value_type_constant);
|
||||
src_type[i] = src_val[i]->type->type;
|
||||
}
|
||||
|
||||
unsigned conv_src_bit_size;
|
||||
switch (opcode) {
|
||||
case SpvOpConvertFToU:
|
||||
case SpvOpConvertFToS:
|
||||
case SpvOpConvertSToF:
|
||||
case SpvOpConvertUToF:
|
||||
case SpvOpSConvert:
|
||||
case SpvOpFConvert:
|
||||
case SpvOpUConvert:
|
||||
/* We have a different source type in a conversion. */
|
||||
org_src_type = vtn_get_value_type(b, w[4])->type;
|
||||
conv_src_bit_size =
|
||||
glsl_type_is_nonnative_float(src_type[0]) ? 32 : glsl_get_bit_size(src_type[0]);
|
||||
break;
|
||||
default:
|
||||
/* When picking ALU ops, bit-size is only used for Convert
|
||||
* operations.
|
||||
*/
|
||||
conv_src_bit_size = 0;
|
||||
break;
|
||||
};
|
||||
|
||||
const glsl_type *dst_type = org_dst_type;
|
||||
if (glsl_type_is_bfloat_16(dst_type) || glsl_type_is_e4m3fn(dst_type) || glsl_type_is_e5m2(dst_type))
|
||||
dst_type = glsl_float_type();
|
||||
|
||||
const glsl_type *src_type = org_src_type;
|
||||
if (glsl_type_is_bfloat_16(src_type) || glsl_type_is_e4m3fn(src_type) || glsl_type_is_e5m2(src_type))
|
||||
src_type = glsl_float_type();
|
||||
const unsigned dst_bit_size =
|
||||
glsl_type_is_nonnative_float(dst_type) ? 32 : glsl_get_bit_size(dst_type);
|
||||
|
||||
bool exact;
|
||||
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
|
||||
glsl_get_bit_size(src_type),
|
||||
glsl_get_bit_size(dst_type));
|
||||
conv_src_bit_size, dst_bit_size);
|
||||
|
||||
/* No SPIR-V opcodes handled through this path should set exact.
|
||||
* Since it is ignored, assert on it.
|
||||
*/
|
||||
assert(!exact);
|
||||
|
||||
unsigned bit_size = glsl_get_bit_size(dst_type);
|
||||
unsigned resolved_bit_size = dst_bit_size;
|
||||
|
||||
nir_const_value src[3][NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
for (unsigned i = 0; i < count - 4; i++) {
|
||||
struct vtn_value *src_val =
|
||||
vtn_value(b, w[4 + i], vtn_value_type_constant);
|
||||
|
||||
for (unsigned i = 0; i < src_count; i++) {
|
||||
/* If this is an unsized source, pull the bit size from the
|
||||
* source; otherwise, we'll use the bit size from the destination.
|
||||
*/
|
||||
if (!nir_alu_type_get_type_size(nir_op_infos[op].input_types[i])) {
|
||||
if (org_src_type != src_type) {
|
||||
/* Small float conversion. */
|
||||
assert(i == 0);
|
||||
bit_size = glsl_get_bit_size(src_type);
|
||||
} else {
|
||||
bit_size = glsl_get_bit_size(src_val->type->type);
|
||||
}
|
||||
resolved_bit_size = glsl_type_is_nonnative_float(src_type[i]) ?
|
||||
32 : glsl_get_bit_size(src_type[i]);
|
||||
}
|
||||
|
||||
unsigned src_comps = nir_op_infos[op].input_sizes[i] ?
|
||||
|
|
@ -2928,12 +2933,12 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
|
||||
unsigned j = swap ? 1 - i : i;
|
||||
for (unsigned c = 0; c < src_comps; c++) {
|
||||
src[j][c] = src_val->constant->values[c];
|
||||
if (glsl_type_is_bfloat_16(org_src_type))
|
||||
src[j][c] = src_val[i]->constant->values[c];
|
||||
if (glsl_type_is_bfloat_16(src_type[i]))
|
||||
src[j][c].f32 = _mesa_bfloat16_bits_to_float(src[j][c].u16);
|
||||
else if (glsl_type_is_e4m3fn(org_src_type))
|
||||
else if (glsl_type_is_e4m3fn(src_type[i]))
|
||||
src[j][c].f32 = _mesa_e4m3fn_to_float(src[j][c].u8);
|
||||
else if (glsl_type_is_e5m2(org_src_type))
|
||||
else if (glsl_type_is_e5m2(src_type[i]))
|
||||
src[j][c].f32 = _mesa_e5m2_to_float(src[j][c].u8);
|
||||
}
|
||||
|
||||
|
|
@ -2945,7 +2950,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
/* Shift amount in NIR ops must be 32-bit. */
|
||||
vtn_assert(!swap);
|
||||
const unsigned shift_idx = 1;
|
||||
const unsigned shift_bit_size = glsl_get_bit_size(src_val->type->type);
|
||||
const unsigned shift_bit_size = glsl_get_bit_size(src_type[i]);
|
||||
if (i != shift_idx || shift_bit_size == 32)
|
||||
break;
|
||||
for (unsigned c = 0; c < src_comps; c++) {
|
||||
|
|
@ -2964,19 +2969,19 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
src[0], src[1], src[2],
|
||||
};
|
||||
nir_eval_const_opcode(op, val->constant->values, NULL,
|
||||
num_components, bit_size, srcs,
|
||||
num_components, resolved_bit_size, srcs,
|
||||
b->shader->info.float_controls_execution_mode);
|
||||
|
||||
for (int i = 0; i < num_components; i++) {
|
||||
uint16_t conv;
|
||||
if (glsl_type_is_bfloat_16(org_dst_type)) {
|
||||
if (glsl_type_is_bfloat_16(dst_type)) {
|
||||
conv = _mesa_float_to_bfloat16_bits_rte(val->constant->values[i].f32);
|
||||
} else if (glsl_type_is_e4m3fn(org_dst_type)) {
|
||||
} else if (glsl_type_is_e4m3fn(dst_type)) {
|
||||
if (saturate)
|
||||
conv = _mesa_float_to_e4m3fn_sat(val->constant->values[i].f32);
|
||||
else
|
||||
conv = _mesa_float_to_e4m3fn(val->constant->values[i].f32);
|
||||
} else if (glsl_type_is_e5m2(org_dst_type)) {
|
||||
} else if (glsl_type_is_e5m2(dst_type)) {
|
||||
if (saturate)
|
||||
conv = _mesa_float_to_e5m2_sat(val->constant->values[i].f32);
|
||||
else
|
||||
|
|
@ -2985,7 +2990,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
|
|||
continue;
|
||||
}
|
||||
|
||||
val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(org_dst_type));
|
||||
val->constant->values[i] = nir_const_value_for_raw_uint(conv, glsl_get_bit_size(dst_type));
|
||||
}
|
||||
|
||||
break;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue