From 7c217e540cb174eac99434e2baf0da108be68f28 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Sat, 7 Feb 2026 21:40:31 +0100 Subject: [PATCH] nir: add a pass to optimize fp_math_ctrl Reviewed-by: Rhys Perry Part-of: --- src/compiler/nir/meson.build | 1 + src/compiler/nir/nir.h | 2 + src/compiler/nir/nir_opt_fp_math_ctrl.c | 302 ++++++++++++++++++++++++ 3 files changed, 305 insertions(+) create mode 100644 src/compiler/nir/nir_opt_fp_math_ctrl.c diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build index 9c011260c3f..8047cab8408 100644 --- a/src/compiler/nir/meson.build +++ b/src/compiler/nir/meson.build @@ -268,6 +268,7 @@ else 'nir_opt_dead_cf.c', 'nir_opt_dead_write_vars.c', 'nir_opt_find_array_copies.c', + 'nir_opt_fp_math_ctrl.c', 'nir_opt_frag_coord_to_pixel_coord.c', 'nir_opt_fragdepth.c', 'nir_opt_gcm.c', diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 092869b905e..fb061c88d6d 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6502,6 +6502,8 @@ bool nir_opt_algebraic_integer_promotion(nir_shader *shader); bool nir_opt_reassociate_matrix_mul(nir_shader *shader); bool nir_opt_constant_folding(nir_shader *shader); +bool nir_opt_fp_math_ctrl(nir_shader *shader); + nir_def *nir_try_constant_fold_alu(nir_builder *b, nir_alu_instr *alu); /* Try to combine a and b into a. Return true if combination was possible, diff --git a/src/compiler/nir/nir_opt_fp_math_ctrl.c b/src/compiler/nir/nir_opt_fp_math_ctrl.c new file mode 100644 index 00000000000..8dc265eeb77 --- /dev/null +++ b/src/compiler/nir/nir_opt_fp_math_ctrl.c @@ -0,0 +1,302 @@ +/* + * Copyright 2026 Valve Corporation + * SPDX-License-Identifier: MIT + */ + +#include "nir.h" +#include "nir_range_analysis.h" + +/* Applications like DXVK spam float control restriction flags on all ALU instructions. + * Remove signed zero, Inf, NaN preserve flags if we can prove that no inputs/outputs + * are zero/inf/NaN using range analysis. + * For signed zero, we can go a step further by back propagating when signed zero are not + * needed, which is a quite common. For example, any float comparison, cosinus, exp2, log2, + * or addition with non zero value does not care about the zero sign of the inputs. Neither + * do texture coordinates. + * + * Future work could also consider fragment output state, fixed point or R11G11B10 formats + * do not care about the sign of zero. + * For pre raster stages, position doesn't care, and we could back propagate information from + * the FS for varyings, and interpolated varyings do not care anyway. + */ + +struct opt_fp_ctrl_state { + nir_fp_analysis_state fp_class_state; +}; + +static inline nir_block * +block_get_loop_preheader(nir_block *block) +{ + nir_cf_node *parent = block->cf_node.parent; + if (parent->type != nir_cf_node_loop) + return NULL; + if (block != nir_cf_node_cf_tree_first(parent)) + return NULL; + return nir_cf_node_as_block(nir_cf_node_prev(parent)); +} + + +static bool +src_mark_preserve_sz(nir_src *src, UNUSED void *state) +{ + nir_def_instr(src->ssa)->pass_flags = true; + return true; +} + +static bool +can_prop_nsz(nir_alu_instr *alu) +{ + /* Only divide cares about the sign of zero even when the sign of zero + * of the output doesn't matter. + */ + switch (alu->op) { + case nir_op_fdiv: + case nir_op_frcp: + case nir_op_frsq: + case nir_op_fcopysign_pco: + return false; + default: + return true; + } +} + +static bool +opt_alu_fp_math_ctrl(nir_alu_instr *alu, struct opt_fp_ctrl_state *state) +{ + if (alu->op == nir_op_bcsel) { + src_mark_preserve_sz(&alu->src[0].src, NULL); + + if (alu->instr.pass_flags) { + src_mark_preserve_sz(&alu->src[1].src, NULL); + src_mark_preserve_sz(&alu->src[2].src, NULL); + } + return false; + } else if (nir_op_is_vec_or_mov(alu->op) && !alu->instr.pass_flags) { + return false; + } + + const nir_op_info *op_info = &nir_op_infos[(int)alu->op]; + unsigned old_fp_math_ctrl = alu->fp_math_ctrl; + if (alu->fp_math_ctrl & nir_fp_preserve_sz_inf_nan) { + fp_class_mask class_mask = 0; + + bool dest_is_float = nir_alu_type_get_base_type(op_info->output_type) == nir_type_float; + + if (dest_is_float) { + class_mask |= nir_analyze_fp_class(&state->fp_class_state, &alu->def); + if (can_prop_nsz(alu) && (!alu->instr.pass_flags || !(class_mask & FP_CLASS_ANY_ZERO))) + alu->fp_math_ctrl &= ~nir_fp_preserve_signed_zero; + } + + for (unsigned i = 0; i < op_info->num_inputs; i++) { + if (nir_alu_type_get_base_type(op_info->input_types[i]) == nir_type_float) + class_mask |= nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa); + } + + /* If class_mask is 0, the opcode has no float operands/definition, + * So it must be a special opcode, or operate on bfloats. Ignore these. + */ + if (class_mask) { + /* If none of the float operands or the definition can be zero/Inf/NaN, + * remove the matching fp_math_ctrl flag. + */ + if (!(class_mask & FP_CLASS_ANY_ZERO)) + alu->fp_math_ctrl &= ~nir_fp_preserve_signed_zero; + if (!(class_mask & FP_CLASS_ANY_INF)) + alu->fp_math_ctrl &= ~nir_fp_preserve_inf; + if (!(class_mask & FP_CLASS_NAN)) + alu->fp_math_ctrl &= ~nir_fp_preserve_nan; + } + } + + if (alu->fp_math_ctrl & nir_fp_preserve_signed_zero) { + /* Some alu never cares about the input sign of zero. */ + switch (alu->op) { + case nir_op_fabs: + case nir_op_fsat: + case nir_op_fexp2: + case nir_op_flog2: + case nir_op_fcos: + case nir_op_fcos_amd: + case nir_op_fmulz: + break; + case nir_op_fmin: { + bool had_neg_zero = false; + for (unsigned i = 0; i < 2; i++) { + fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa); + + if (fp_class & (FP_CLASS_NAN | FP_CLASS_ANY_POS | FP_CLASS_POS_ZERO)) { + src_mark_preserve_sz(&alu->src[!i].src, NULL); + } else if (fp_class & FP_CLASS_NEG_ZERO) { + /* If both operands can be -0.0, at least one needs to be preserved. */ + if (had_neg_zero) + src_mark_preserve_sz(&alu->src[!i].src, NULL); + had_neg_zero = true; + } + } + break; + } + + case nir_op_fmax: { + bool had_pos_zero = false; + for (unsigned i = 0; i < 2; i++) { + fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa); + + if (fp_class & (FP_CLASS_NAN | FP_CLASS_ANY_NEG | FP_CLASS_NEG_ZERO)) { + src_mark_preserve_sz(&alu->src[!i].src, NULL); + } else if (fp_class & FP_CLASS_POS_ZERO) { + /* If both operands can be +0.0, at least one needs to be preserved. */ + if (had_pos_zero) + src_mark_preserve_sz(&alu->src[!i].src, NULL); + had_pos_zero = true; + } + } + break; + } + case nir_op_fsub: + case nir_op_fadd: { + bool had_pos_zero = false; + for (unsigned i = 0; i < 2; i++) { + fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa); + + bool negate = i == 1 && alu->op == nir_op_fsub; + + if (fp_class & (negate ? FP_CLASS_POS_ZERO : FP_CLASS_NEG_ZERO)) { + src_mark_preserve_sz(&alu->src[!i].src, NULL); + } else if (fp_class & (negate ? FP_CLASS_NEG_ZERO : FP_CLASS_POS_ZERO)) { + /* If both operands can be +0.0, at least one needs to be preserved. */ + if (had_pos_zero) + src_mark_preserve_sz(&alu->src[!i].src, NULL); + had_pos_zero = true; + } + } + break; + } + + case nir_op_ffmaz: + src_mark_preserve_sz(&alu->src[2].src, NULL); + break; + case nir_op_ffma: + if ((nir_analyze_fp_class(&state->fp_class_state, alu->src[2].src.ssa) & FP_CLASS_NEG_ZERO) && + !nir_alu_srcs_equal(alu, alu, 0, 1)) { + src_mark_preserve_sz(&alu->src[0].src, NULL); + src_mark_preserve_sz(&alu->src[1].src, NULL); + } + src_mark_preserve_sz(&alu->src[2].src, NULL); + break; + case nir_op_fmul: + if (nir_alu_srcs_equal(alu, alu, 0, 1)) + break; + FALLTHROUGH; + default: + for (unsigned i = 0; i < op_info->num_inputs; i++) + src_mark_preserve_sz(&alu->src[i].src, NULL); + break; + } + } else { + /* Only preserve signed zeros for non float operands. */ + for (unsigned i = 0; i < op_info->num_inputs; i++) { + if (nir_alu_type_get_base_type(op_info->input_types[i]) != nir_type_float) + src_mark_preserve_sz(&alu->src[i].src, NULL); + } + } + + return alu->fp_math_ctrl != old_fp_math_ctrl; +} + +static void +prop_tex_fp_math_ctrl(nir_tex_instr *tex) +{ + for (unsigned i = 0; i < tex->num_srcs; i++) { + /* Floating point tex sources don't care about sign of zero. */ + if (nir_tex_instr_src_type(tex, i) != nir_type_float) + src_mark_preserve_sz(&tex->src[i].src, NULL); + } +} + +static void +prop_intrin_fp_math_ctrl(nir_intrinsic_instr *intrin) +{ + switch (intrin->intrinsic) { + case nir_intrinsic_ddx: + case nir_intrinsic_ddx_coarse: + case nir_intrinsic_ddx_fine: + case nir_intrinsic_ddy: + case nir_intrinsic_ddy_coarse: + case nir_intrinsic_ddy_fine: + if (intrin->instr.pass_flags) + src_mark_preserve_sz(&intrin->src[0], NULL); + break; + default: + nir_foreach_src(&intrin->instr, src_mark_preserve_sz, NULL); + break; + } +} + +static bool +opt_fp_math_ctrl_impl(nir_function_impl *impl) +{ + /* Setup pass flags: Store if signed zeros are needed. + * Handle loop header phis here already: + * For now, just disable opts for the back edges. + * That could be improved, but not sure if it's worth it. + */ + nir_foreach_block_reverse(block, impl) { + nir_block *preheader = block_get_loop_preheader(block); + + nir_foreach_instr_reverse(instr, block) { + instr->pass_flags = false; + if (instr->type == nir_instr_type_phi && preheader) { + nir_phi_instr *phi = nir_instr_as_phi(instr); + nir_foreach_phi_src(src, phi) { + if (src->pred != preheader) + nir_def_instr(src->src.ssa)->pass_flags = true; + } + } + } + } + + + struct opt_fp_ctrl_state state = {0}; + state.fp_class_state = nir_create_fp_analysis_state(impl); + + bool progress = false; + + nir_foreach_block_reverse(block, impl) { + nir_foreach_instr_reverse(instr, block) { + switch (instr->type) { + case nir_instr_type_alu: + progress |= opt_alu_fp_math_ctrl(nir_instr_as_alu(instr), &state); + break; + case nir_instr_type_tex: + prop_tex_fp_math_ctrl(nir_instr_as_tex(instr)); + break; + case nir_instr_type_intrinsic: + prop_intrin_fp_math_ctrl(nir_instr_as_intrinsic(instr)); + break; + case nir_instr_type_phi: + if (!instr->pass_flags) + break; + FALLTHROUGH; + default: + nir_foreach_src(instr, src_mark_preserve_sz, NULL); + break; + } + } + } + + nir_free_fp_analysis_state(&state.fp_class_state); + + return nir_progress(progress, impl, nir_metadata_all); +} + +bool +nir_opt_fp_math_ctrl(nir_shader *shader) +{ + bool progress = false; + + nir_foreach_function_impl(impl, shader) + progress |= opt_fp_math_ctrl_impl(impl); + + return progress; +}