diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 8f8a355dbda..0f039626d7c 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -402,6 +402,25 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, } } +void +vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val) +{ + /* Take the NaN/Inf/SZ preserve bits from the execution mode and set them + * on the builder, so the generated instructions can take it from it. + * We only care about some of them, check nir_alu_instr for details. + * We also copy all bit widths, because we can't easily get the correct one + * here. + */ +#define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \ + FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \ + FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64) + static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9), + "enum float_controls and fp_fast_math out of sync!"); + b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode & + FLOAT_CONTROLS2_BITS; +#undef FLOAT_CONTROLS2_BITS +} + static void handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val, UNUSED int member, const struct vtn_decoration *dec, @@ -581,6 +600,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, } vtn_handle_no_contraction(b, dest_val); + vtn_handle_fp_fast_math(b, dest_val); bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val); /* Collect the various SSA sources */ diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c index c7d90021e1f..df529265c25 100644 --- a/src/compiler/spirv/vtn_glsl450.c +++ b/src/compiler/spirv/vtn_glsl450.c @@ -697,6 +697,7 @@ bool vtn_handle_glsl450_instruction(struct vtn_builder *b, SpvOp ext_opcode, const uint32_t *w, unsigned count) { + vtn_handle_fp_fast_math(b, vtn_untyped_value(b, w[2])); switch ((enum GLSLstd450)ext_opcode) { case GLSLstd450Determinant: { vtn_push_nir_ssa(b, w[2], build_mat_det(b, vtn_ssa_value(b, w[5]))); diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index d703224537d..dcb905dc561 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -971,6 +971,8 @@ void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, void vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val); +void vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val); + void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count);