diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index d3b2c4c3b19..fd55c06a81b 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -5734,9 +5734,15 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point, b->shader->info.fs.sample_interlock_unordered = true; break; + case SpvExecutionModeSignedZeroInfNanPreserve: { + const struct glsl_type *type = glsl_floatN_t_type(mode->operands[0]); + unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, glsl_get_base_type(type)); + *fp_math_ctrl |= nir_fp_preserve_sz_inf_nan; + break; + } + case SpvExecutionModeDenormPreserve: case SpvExecutionModeDenormFlushToZero: - case SpvExecutionModeSignedZeroInfNanPreserve: case SpvExecutionModeRoundingModeRTE: case SpvExecutionModeRoundingModeRTZ: { unsigned execution_mode = 0; @@ -5757,14 +5763,6 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point, default: vtn_fail("Floating point type not supported"); } break; - case SpvExecutionModeSignedZeroInfNanPreserve: - switch (mode->operands[0]) { - case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_sz_inf_nan; break; - case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_sz_inf_nan; break; - case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_sz_inf_nan; break; - default: vtn_fail("Floating point type not supported"); - } - break; case SpvExecutionModeRoundingModeRTE: switch (mode->operands[0]) { case 16: execution_mode = FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16; break; @@ -5925,6 +5923,12 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin struct vtn_type *type = vtn_get_type(b, mode->operands[0]); SpvFPFastMathModeMask flags = vtn_constant_uint(b, mode->operands[1]); + enum glsl_base_type base_type = glsl_get_base_type(type->type); + unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, base_type); + + if (!fp_math_ctrl) + vtn_fail("Unkown float type for FPFastMathDefault"); + SpvFPFastMathModeMask can_fast_math = SpvFPFastMathModeAllowRecipMask | SpvFPFastMathModeAllowContractMask | @@ -5933,27 +5937,15 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin if ((flags & can_fast_math) != can_fast_math) b->exact = true; - if (!(flags & SpvFPFastMathModeNotNaNMask)) { - switch (glsl_get_bit_size(type->type)) { - case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_nan; break; - case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_nan; break; - case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_nan; break; - } - } - if (!(flags & SpvFPFastMathModeNotInfMask)) { - switch (glsl_get_bit_size(type->type)) { - case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_inf; break; - case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_inf; break; - case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_inf; break; - } - } - if (!(flags & SpvFPFastMathModeNSZMask)) { - switch (glsl_get_bit_size(type->type)) { - case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_signed_zero; break; - case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_signed_zero; break; - case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_signed_zero; break; - } - } + if (!(flags & SpvFPFastMathModeNotNaNMask)) + *fp_math_ctrl |= nir_fp_preserve_nan; + + if (!(flags & SpvFPFastMathModeNotInfMask)) + *fp_math_ctrl |= nir_fp_preserve_inf; + + if (!(flags & SpvFPFastMathModeNSZMask)) + *fp_math_ctrl |= nir_fp_preserve_signed_zero; + break; } diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 8d2256055f5..73c1bc0a6c7 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -425,6 +425,39 @@ handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val, b->nb.fp_math_ctrl |= nir_fp_preserve_inf; } +unsigned * +vtn_fp_math_ctrl_for_base_type(struct vtn_builder *b, enum glsl_base_type base_type) +{ + switch (base_type) { + case GLSL_TYPE_FLOAT16: return &b->fp_math_ctrl[0]; + case GLSL_TYPE_FLOAT: return &b->fp_math_ctrl[1]; + case GLSL_TYPE_DOUBLE: return &b->fp_math_ctrl[2]; + case GLSL_TYPE_BFLOAT16: return &b->fp_math_ctrl[3]; + case GLSL_TYPE_FLOAT_E4M3FN: return &b->fp_math_ctrl[4]; + case GLSL_TYPE_FLOAT_E5M2: return &b->fp_math_ctrl[5]; + default: return NULL; + } +} + +static unsigned +fp_math_ctrl_for_type(struct vtn_builder *b, struct vtn_type *type) +{ + if (!type) + return nir_fp_fast_math; + + enum glsl_base_type base_type; + + /* Some ALU like modf and frexp return a struct of two values. */ + if (glsl_type_is_struct(type->type)) + base_type = glsl_get_base_type(type->type->fields.structure[0].type); + else + base_type = glsl_get_base_type(type->type); + + unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, base_type); + + return fp_math_ctrl ? *fp_math_ctrl : nir_fp_fast_math; +} + void vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val) { @@ -432,23 +465,8 @@ vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val) * on the builder, so the generated instructions can take it from it. * We only care about some of them, check nir_alu_instr for details. */ - unsigned bit_size; - /* Some ALU like modf and frexp return a struct of two values. */ - if (!val->type) - bit_size = 0; - else if (glsl_type_is_struct(val->type->type)) - bit_size = glsl_get_bit_size(val->type->type->fields.structure[0].type); - else - bit_size = glsl_get_bit_size(val->type->type); - - - switch (bit_size) { - case 16: b->nb.fp_math_ctrl = b->fp_math_ctrl_fp16; break; - case 32: b->nb.fp_math_ctrl = b->fp_math_ctrl_fp32; break; - case 64: b->nb.fp_math_ctrl = b->fp_math_ctrl_fp64; break; - default: b->nb.fp_math_ctrl = 0; break; - } + b->nb.fp_math_ctrl = fp_math_ctrl_for_type(b, val->type); vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL); diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 9aa6cb4c912..35ad4b0ad55 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -710,9 +710,7 @@ struct vtn_builder { /* false by default, set to true by the ContractionOff execution mode */ bool exact; - unsigned fp_math_ctrl_fp16; - unsigned fp_math_ctrl_fp32; - unsigned fp_math_ctrl_fp64; + unsigned fp_math_ctrl[6]; /* when a physical memory model is choosen */ bool physical_ptrs; @@ -992,6 +990,8 @@ void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode, void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count); +unsigned *vtn_fp_math_ctrl_for_base_type(struct vtn_builder *b, enum glsl_base_type base_type); + void vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val); void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,