nir: add a pass to optimize fp_math_ctrl

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40098>
This commit is contained in:
Georg Lehmann 2026-02-07 21:40:31 +01:00
parent 042ee8dafc
commit 7c217e540c
3 changed files with 305 additions and 0 deletions

View file

@ -268,6 +268,7 @@ else
'nir_opt_dead_cf.c',
'nir_opt_dead_write_vars.c',
'nir_opt_find_array_copies.c',
'nir_opt_fp_math_ctrl.c',
'nir_opt_frag_coord_to_pixel_coord.c',
'nir_opt_fragdepth.c',
'nir_opt_gcm.c',

View file

@ -6502,6 +6502,8 @@ bool nir_opt_algebraic_integer_promotion(nir_shader *shader);
bool nir_opt_reassociate_matrix_mul(nir_shader *shader);
bool nir_opt_constant_folding(nir_shader *shader);
bool nir_opt_fp_math_ctrl(nir_shader *shader);
nir_def *nir_try_constant_fold_alu(nir_builder *b, nir_alu_instr *alu);
/* Try to combine a and b into a. Return true if combination was possible,

View file

@ -0,0 +1,302 @@
/*
* Copyright 2026 Valve Corporation
* SPDX-License-Identifier: MIT
*/
#include "nir.h"
#include "nir_range_analysis.h"
/* Applications like DXVK spam float control restriction flags on all ALU instructions.
* Remove signed zero, Inf, NaN preserve flags if we can prove that no inputs/outputs
* are zero/inf/NaN using range analysis.
* For signed zero, we can go a step further by back propagating when signed zero are not
* needed, which is a quite common. For example, any float comparison, cosinus, exp2, log2,
* or addition with non zero value does not care about the zero sign of the inputs. Neither
* do texture coordinates.
*
* Future work could also consider fragment output state, fixed point or R11G11B10 formats
* do not care about the sign of zero.
* For pre raster stages, position doesn't care, and we could back propagate information from
* the FS for varyings, and interpolated varyings do not care anyway.
*/
struct opt_fp_ctrl_state {
nir_fp_analysis_state fp_class_state;
};
static inline nir_block *
block_get_loop_preheader(nir_block *block)
{
nir_cf_node *parent = block->cf_node.parent;
if (parent->type != nir_cf_node_loop)
return NULL;
if (block != nir_cf_node_cf_tree_first(parent))
return NULL;
return nir_cf_node_as_block(nir_cf_node_prev(parent));
}
static bool
src_mark_preserve_sz(nir_src *src, UNUSED void *state)
{
nir_def_instr(src->ssa)->pass_flags = true;
return true;
}
static bool
can_prop_nsz(nir_alu_instr *alu)
{
/* Only divide cares about the sign of zero even when the sign of zero
* of the output doesn't matter.
*/
switch (alu->op) {
case nir_op_fdiv:
case nir_op_frcp:
case nir_op_frsq:
case nir_op_fcopysign_pco:
return false;
default:
return true;
}
}
static bool
opt_alu_fp_math_ctrl(nir_alu_instr *alu, struct opt_fp_ctrl_state *state)
{
if (alu->op == nir_op_bcsel) {
src_mark_preserve_sz(&alu->src[0].src, NULL);
if (alu->instr.pass_flags) {
src_mark_preserve_sz(&alu->src[1].src, NULL);
src_mark_preserve_sz(&alu->src[2].src, NULL);
}
return false;
} else if (nir_op_is_vec_or_mov(alu->op) && !alu->instr.pass_flags) {
return false;
}
const nir_op_info *op_info = &nir_op_infos[(int)alu->op];
unsigned old_fp_math_ctrl = alu->fp_math_ctrl;
if (alu->fp_math_ctrl & nir_fp_preserve_sz_inf_nan) {
fp_class_mask class_mask = 0;
bool dest_is_float = nir_alu_type_get_base_type(op_info->output_type) == nir_type_float;
if (dest_is_float) {
class_mask |= nir_analyze_fp_class(&state->fp_class_state, &alu->def);
if (can_prop_nsz(alu) && (!alu->instr.pass_flags || !(class_mask & FP_CLASS_ANY_ZERO)))
alu->fp_math_ctrl &= ~nir_fp_preserve_signed_zero;
}
for (unsigned i = 0; i < op_info->num_inputs; i++) {
if (nir_alu_type_get_base_type(op_info->input_types[i]) == nir_type_float)
class_mask |= nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
}
/* If class_mask is 0, the opcode has no float operands/definition,
* So it must be a special opcode, or operate on bfloats. Ignore these.
*/
if (class_mask) {
/* If none of the float operands or the definition can be zero/Inf/NaN,
* remove the matching fp_math_ctrl flag.
*/
if (!(class_mask & FP_CLASS_ANY_ZERO))
alu->fp_math_ctrl &= ~nir_fp_preserve_signed_zero;
if (!(class_mask & FP_CLASS_ANY_INF))
alu->fp_math_ctrl &= ~nir_fp_preserve_inf;
if (!(class_mask & FP_CLASS_NAN))
alu->fp_math_ctrl &= ~nir_fp_preserve_nan;
}
}
if (alu->fp_math_ctrl & nir_fp_preserve_signed_zero) {
/* Some alu never cares about the input sign of zero. */
switch (alu->op) {
case nir_op_fabs:
case nir_op_fsat:
case nir_op_fexp2:
case nir_op_flog2:
case nir_op_fcos:
case nir_op_fcos_amd:
case nir_op_fmulz:
break;
case nir_op_fmin: {
bool had_neg_zero = false;
for (unsigned i = 0; i < 2; i++) {
fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
if (fp_class & (FP_CLASS_NAN | FP_CLASS_ANY_POS | FP_CLASS_POS_ZERO)) {
src_mark_preserve_sz(&alu->src[!i].src, NULL);
} else if (fp_class & FP_CLASS_NEG_ZERO) {
/* If both operands can be -0.0, at least one needs to be preserved. */
if (had_neg_zero)
src_mark_preserve_sz(&alu->src[!i].src, NULL);
had_neg_zero = true;
}
}
break;
}
case nir_op_fmax: {
bool had_pos_zero = false;
for (unsigned i = 0; i < 2; i++) {
fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
if (fp_class & (FP_CLASS_NAN | FP_CLASS_ANY_NEG | FP_CLASS_NEG_ZERO)) {
src_mark_preserve_sz(&alu->src[!i].src, NULL);
} else if (fp_class & FP_CLASS_POS_ZERO) {
/* If both operands can be +0.0, at least one needs to be preserved. */
if (had_pos_zero)
src_mark_preserve_sz(&alu->src[!i].src, NULL);
had_pos_zero = true;
}
}
break;
}
case nir_op_fsub:
case nir_op_fadd: {
bool had_pos_zero = false;
for (unsigned i = 0; i < 2; i++) {
fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
bool negate = i == 1 && alu->op == nir_op_fsub;
if (fp_class & (negate ? FP_CLASS_POS_ZERO : FP_CLASS_NEG_ZERO)) {
src_mark_preserve_sz(&alu->src[!i].src, NULL);
} else if (fp_class & (negate ? FP_CLASS_NEG_ZERO : FP_CLASS_POS_ZERO)) {
/* If both operands can be +0.0, at least one needs to be preserved. */
if (had_pos_zero)
src_mark_preserve_sz(&alu->src[!i].src, NULL);
had_pos_zero = true;
}
}
break;
}
case nir_op_ffmaz:
src_mark_preserve_sz(&alu->src[2].src, NULL);
break;
case nir_op_ffma:
if ((nir_analyze_fp_class(&state->fp_class_state, alu->src[2].src.ssa) & FP_CLASS_NEG_ZERO) &&
!nir_alu_srcs_equal(alu, alu, 0, 1)) {
src_mark_preserve_sz(&alu->src[0].src, NULL);
src_mark_preserve_sz(&alu->src[1].src, NULL);
}
src_mark_preserve_sz(&alu->src[2].src, NULL);
break;
case nir_op_fmul:
if (nir_alu_srcs_equal(alu, alu, 0, 1))
break;
FALLTHROUGH;
default:
for (unsigned i = 0; i < op_info->num_inputs; i++)
src_mark_preserve_sz(&alu->src[i].src, NULL);
break;
}
} else {
/* Only preserve signed zeros for non float operands. */
for (unsigned i = 0; i < op_info->num_inputs; i++) {
if (nir_alu_type_get_base_type(op_info->input_types[i]) != nir_type_float)
src_mark_preserve_sz(&alu->src[i].src, NULL);
}
}
return alu->fp_math_ctrl != old_fp_math_ctrl;
}
static void
prop_tex_fp_math_ctrl(nir_tex_instr *tex)
{
for (unsigned i = 0; i < tex->num_srcs; i++) {
/* Floating point tex sources don't care about sign of zero. */
if (nir_tex_instr_src_type(tex, i) != nir_type_float)
src_mark_preserve_sz(&tex->src[i].src, NULL);
}
}
static void
prop_intrin_fp_math_ctrl(nir_intrinsic_instr *intrin)
{
switch (intrin->intrinsic) {
case nir_intrinsic_ddx:
case nir_intrinsic_ddx_coarse:
case nir_intrinsic_ddx_fine:
case nir_intrinsic_ddy:
case nir_intrinsic_ddy_coarse:
case nir_intrinsic_ddy_fine:
if (intrin->instr.pass_flags)
src_mark_preserve_sz(&intrin->src[0], NULL);
break;
default:
nir_foreach_src(&intrin->instr, src_mark_preserve_sz, NULL);
break;
}
}
static bool
opt_fp_math_ctrl_impl(nir_function_impl *impl)
{
/* Setup pass flags: Store if signed zeros are needed.
* Handle loop header phis here already:
* For now, just disable opts for the back edges.
* That could be improved, but not sure if it's worth it.
*/
nir_foreach_block_reverse(block, impl) {
nir_block *preheader = block_get_loop_preheader(block);
nir_foreach_instr_reverse(instr, block) {
instr->pass_flags = false;
if (instr->type == nir_instr_type_phi && preheader) {
nir_phi_instr *phi = nir_instr_as_phi(instr);
nir_foreach_phi_src(src, phi) {
if (src->pred != preheader)
nir_def_instr(src->src.ssa)->pass_flags = true;
}
}
}
}
struct opt_fp_ctrl_state state = {0};
state.fp_class_state = nir_create_fp_analysis_state(impl);
bool progress = false;
nir_foreach_block_reverse(block, impl) {
nir_foreach_instr_reverse(instr, block) {
switch (instr->type) {
case nir_instr_type_alu:
progress |= opt_alu_fp_math_ctrl(nir_instr_as_alu(instr), &state);
break;
case nir_instr_type_tex:
prop_tex_fp_math_ctrl(nir_instr_as_tex(instr));
break;
case nir_instr_type_intrinsic:
prop_intrin_fp_math_ctrl(nir_instr_as_intrinsic(instr));
break;
case nir_instr_type_phi:
if (!instr->pass_flags)
break;
FALLTHROUGH;
default:
nir_foreach_src(instr, src_mark_preserve_sz, NULL);
break;
}
}
}
nir_free_fp_analysis_state(&state.fp_class_state);
return nir_progress(progress, impl, nir_metadata_all);
}
bool
nir_opt_fp_math_ctrl(nir_shader *shader)
{
bool progress = false;
nir_foreach_function_impl(impl, shader)
progress |= opt_fp_math_ctrl_impl(impl);
return progress;
}