mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-11 15:40:31 +01:00
nir: add a pass to optimize fp_math_ctrl
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40098>
This commit is contained in:
parent
042ee8dafc
commit
7c217e540c
3 changed files with 305 additions and 0 deletions
|
|
@ -268,6 +268,7 @@ else
|
|||
'nir_opt_dead_cf.c',
|
||||
'nir_opt_dead_write_vars.c',
|
||||
'nir_opt_find_array_copies.c',
|
||||
'nir_opt_fp_math_ctrl.c',
|
||||
'nir_opt_frag_coord_to_pixel_coord.c',
|
||||
'nir_opt_fragdepth.c',
|
||||
'nir_opt_gcm.c',
|
||||
|
|
|
|||
|
|
@ -6502,6 +6502,8 @@ bool nir_opt_algebraic_integer_promotion(nir_shader *shader);
|
|||
bool nir_opt_reassociate_matrix_mul(nir_shader *shader);
|
||||
bool nir_opt_constant_folding(nir_shader *shader);
|
||||
|
||||
bool nir_opt_fp_math_ctrl(nir_shader *shader);
|
||||
|
||||
nir_def *nir_try_constant_fold_alu(nir_builder *b, nir_alu_instr *alu);
|
||||
|
||||
/* Try to combine a and b into a. Return true if combination was possible,
|
||||
|
|
|
|||
302
src/compiler/nir/nir_opt_fp_math_ctrl.c
Normal file
302
src/compiler/nir/nir_opt_fp_math_ctrl.c
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
/*
|
||||
* Copyright 2026 Valve Corporation
|
||||
* SPDX-License-Identifier: MIT
|
||||
*/
|
||||
|
||||
#include "nir.h"
|
||||
#include "nir_range_analysis.h"
|
||||
|
||||
/* Applications like DXVK spam float control restriction flags on all ALU instructions.
|
||||
* Remove signed zero, Inf, NaN preserve flags if we can prove that no inputs/outputs
|
||||
* are zero/inf/NaN using range analysis.
|
||||
* For signed zero, we can go a step further by back propagating when signed zero are not
|
||||
* needed, which is a quite common. For example, any float comparison, cosinus, exp2, log2,
|
||||
* or addition with non zero value does not care about the zero sign of the inputs. Neither
|
||||
* do texture coordinates.
|
||||
*
|
||||
* Future work could also consider fragment output state, fixed point or R11G11B10 formats
|
||||
* do not care about the sign of zero.
|
||||
* For pre raster stages, position doesn't care, and we could back propagate information from
|
||||
* the FS for varyings, and interpolated varyings do not care anyway.
|
||||
*/
|
||||
|
||||
struct opt_fp_ctrl_state {
|
||||
nir_fp_analysis_state fp_class_state;
|
||||
};
|
||||
|
||||
static inline nir_block *
|
||||
block_get_loop_preheader(nir_block *block)
|
||||
{
|
||||
nir_cf_node *parent = block->cf_node.parent;
|
||||
if (parent->type != nir_cf_node_loop)
|
||||
return NULL;
|
||||
if (block != nir_cf_node_cf_tree_first(parent))
|
||||
return NULL;
|
||||
return nir_cf_node_as_block(nir_cf_node_prev(parent));
|
||||
}
|
||||
|
||||
|
||||
static bool
|
||||
src_mark_preserve_sz(nir_src *src, UNUSED void *state)
|
||||
{
|
||||
nir_def_instr(src->ssa)->pass_flags = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
can_prop_nsz(nir_alu_instr *alu)
|
||||
{
|
||||
/* Only divide cares about the sign of zero even when the sign of zero
|
||||
* of the output doesn't matter.
|
||||
*/
|
||||
switch (alu->op) {
|
||||
case nir_op_fdiv:
|
||||
case nir_op_frcp:
|
||||
case nir_op_frsq:
|
||||
case nir_op_fcopysign_pco:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
opt_alu_fp_math_ctrl(nir_alu_instr *alu, struct opt_fp_ctrl_state *state)
|
||||
{
|
||||
if (alu->op == nir_op_bcsel) {
|
||||
src_mark_preserve_sz(&alu->src[0].src, NULL);
|
||||
|
||||
if (alu->instr.pass_flags) {
|
||||
src_mark_preserve_sz(&alu->src[1].src, NULL);
|
||||
src_mark_preserve_sz(&alu->src[2].src, NULL);
|
||||
}
|
||||
return false;
|
||||
} else if (nir_op_is_vec_or_mov(alu->op) && !alu->instr.pass_flags) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const nir_op_info *op_info = &nir_op_infos[(int)alu->op];
|
||||
unsigned old_fp_math_ctrl = alu->fp_math_ctrl;
|
||||
if (alu->fp_math_ctrl & nir_fp_preserve_sz_inf_nan) {
|
||||
fp_class_mask class_mask = 0;
|
||||
|
||||
bool dest_is_float = nir_alu_type_get_base_type(op_info->output_type) == nir_type_float;
|
||||
|
||||
if (dest_is_float) {
|
||||
class_mask |= nir_analyze_fp_class(&state->fp_class_state, &alu->def);
|
||||
if (can_prop_nsz(alu) && (!alu->instr.pass_flags || !(class_mask & FP_CLASS_ANY_ZERO)))
|
||||
alu->fp_math_ctrl &= ~nir_fp_preserve_signed_zero;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < op_info->num_inputs; i++) {
|
||||
if (nir_alu_type_get_base_type(op_info->input_types[i]) == nir_type_float)
|
||||
class_mask |= nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
|
||||
}
|
||||
|
||||
/* If class_mask is 0, the opcode has no float operands/definition,
|
||||
* So it must be a special opcode, or operate on bfloats. Ignore these.
|
||||
*/
|
||||
if (class_mask) {
|
||||
/* If none of the float operands or the definition can be zero/Inf/NaN,
|
||||
* remove the matching fp_math_ctrl flag.
|
||||
*/
|
||||
if (!(class_mask & FP_CLASS_ANY_ZERO))
|
||||
alu->fp_math_ctrl &= ~nir_fp_preserve_signed_zero;
|
||||
if (!(class_mask & FP_CLASS_ANY_INF))
|
||||
alu->fp_math_ctrl &= ~nir_fp_preserve_inf;
|
||||
if (!(class_mask & FP_CLASS_NAN))
|
||||
alu->fp_math_ctrl &= ~nir_fp_preserve_nan;
|
||||
}
|
||||
}
|
||||
|
||||
if (alu->fp_math_ctrl & nir_fp_preserve_signed_zero) {
|
||||
/* Some alu never cares about the input sign of zero. */
|
||||
switch (alu->op) {
|
||||
case nir_op_fabs:
|
||||
case nir_op_fsat:
|
||||
case nir_op_fexp2:
|
||||
case nir_op_flog2:
|
||||
case nir_op_fcos:
|
||||
case nir_op_fcos_amd:
|
||||
case nir_op_fmulz:
|
||||
break;
|
||||
case nir_op_fmin: {
|
||||
bool had_neg_zero = false;
|
||||
for (unsigned i = 0; i < 2; i++) {
|
||||
fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
|
||||
|
||||
if (fp_class & (FP_CLASS_NAN | FP_CLASS_ANY_POS | FP_CLASS_POS_ZERO)) {
|
||||
src_mark_preserve_sz(&alu->src[!i].src, NULL);
|
||||
} else if (fp_class & FP_CLASS_NEG_ZERO) {
|
||||
/* If both operands can be -0.0, at least one needs to be preserved. */
|
||||
if (had_neg_zero)
|
||||
src_mark_preserve_sz(&alu->src[!i].src, NULL);
|
||||
had_neg_zero = true;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_fmax: {
|
||||
bool had_pos_zero = false;
|
||||
for (unsigned i = 0; i < 2; i++) {
|
||||
fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
|
||||
|
||||
if (fp_class & (FP_CLASS_NAN | FP_CLASS_ANY_NEG | FP_CLASS_NEG_ZERO)) {
|
||||
src_mark_preserve_sz(&alu->src[!i].src, NULL);
|
||||
} else if (fp_class & FP_CLASS_POS_ZERO) {
|
||||
/* If both operands can be +0.0, at least one needs to be preserved. */
|
||||
if (had_pos_zero)
|
||||
src_mark_preserve_sz(&alu->src[!i].src, NULL);
|
||||
had_pos_zero = true;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_op_fsub:
|
||||
case nir_op_fadd: {
|
||||
bool had_pos_zero = false;
|
||||
for (unsigned i = 0; i < 2; i++) {
|
||||
fp_class_mask fp_class = nir_analyze_fp_class(&state->fp_class_state, alu->src[i].src.ssa);
|
||||
|
||||
bool negate = i == 1 && alu->op == nir_op_fsub;
|
||||
|
||||
if (fp_class & (negate ? FP_CLASS_POS_ZERO : FP_CLASS_NEG_ZERO)) {
|
||||
src_mark_preserve_sz(&alu->src[!i].src, NULL);
|
||||
} else if (fp_class & (negate ? FP_CLASS_NEG_ZERO : FP_CLASS_POS_ZERO)) {
|
||||
/* If both operands can be +0.0, at least one needs to be preserved. */
|
||||
if (had_pos_zero)
|
||||
src_mark_preserve_sz(&alu->src[!i].src, NULL);
|
||||
had_pos_zero = true;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_ffmaz:
|
||||
src_mark_preserve_sz(&alu->src[2].src, NULL);
|
||||
break;
|
||||
case nir_op_ffma:
|
||||
if ((nir_analyze_fp_class(&state->fp_class_state, alu->src[2].src.ssa) & FP_CLASS_NEG_ZERO) &&
|
||||
!nir_alu_srcs_equal(alu, alu, 0, 1)) {
|
||||
src_mark_preserve_sz(&alu->src[0].src, NULL);
|
||||
src_mark_preserve_sz(&alu->src[1].src, NULL);
|
||||
}
|
||||
src_mark_preserve_sz(&alu->src[2].src, NULL);
|
||||
break;
|
||||
case nir_op_fmul:
|
||||
if (nir_alu_srcs_equal(alu, alu, 0, 1))
|
||||
break;
|
||||
FALLTHROUGH;
|
||||
default:
|
||||
for (unsigned i = 0; i < op_info->num_inputs; i++)
|
||||
src_mark_preserve_sz(&alu->src[i].src, NULL);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
/* Only preserve signed zeros for non float operands. */
|
||||
for (unsigned i = 0; i < op_info->num_inputs; i++) {
|
||||
if (nir_alu_type_get_base_type(op_info->input_types[i]) != nir_type_float)
|
||||
src_mark_preserve_sz(&alu->src[i].src, NULL);
|
||||
}
|
||||
}
|
||||
|
||||
return alu->fp_math_ctrl != old_fp_math_ctrl;
|
||||
}
|
||||
|
||||
static void
|
||||
prop_tex_fp_math_ctrl(nir_tex_instr *tex)
|
||||
{
|
||||
for (unsigned i = 0; i < tex->num_srcs; i++) {
|
||||
/* Floating point tex sources don't care about sign of zero. */
|
||||
if (nir_tex_instr_src_type(tex, i) != nir_type_float)
|
||||
src_mark_preserve_sz(&tex->src[i].src, NULL);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
prop_intrin_fp_math_ctrl(nir_intrinsic_instr *intrin)
|
||||
{
|
||||
switch (intrin->intrinsic) {
|
||||
case nir_intrinsic_ddx:
|
||||
case nir_intrinsic_ddx_coarse:
|
||||
case nir_intrinsic_ddx_fine:
|
||||
case nir_intrinsic_ddy:
|
||||
case nir_intrinsic_ddy_coarse:
|
||||
case nir_intrinsic_ddy_fine:
|
||||
if (intrin->instr.pass_flags)
|
||||
src_mark_preserve_sz(&intrin->src[0], NULL);
|
||||
break;
|
||||
default:
|
||||
nir_foreach_src(&intrin->instr, src_mark_preserve_sz, NULL);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
opt_fp_math_ctrl_impl(nir_function_impl *impl)
|
||||
{
|
||||
/* Setup pass flags: Store if signed zeros are needed.
|
||||
* Handle loop header phis here already:
|
||||
* For now, just disable opts for the back edges.
|
||||
* That could be improved, but not sure if it's worth it.
|
||||
*/
|
||||
nir_foreach_block_reverse(block, impl) {
|
||||
nir_block *preheader = block_get_loop_preheader(block);
|
||||
|
||||
nir_foreach_instr_reverse(instr, block) {
|
||||
instr->pass_flags = false;
|
||||
if (instr->type == nir_instr_type_phi && preheader) {
|
||||
nir_phi_instr *phi = nir_instr_as_phi(instr);
|
||||
nir_foreach_phi_src(src, phi) {
|
||||
if (src->pred != preheader)
|
||||
nir_def_instr(src->src.ssa)->pass_flags = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
struct opt_fp_ctrl_state state = {0};
|
||||
state.fp_class_state = nir_create_fp_analysis_state(impl);
|
||||
|
||||
bool progress = false;
|
||||
|
||||
nir_foreach_block_reverse(block, impl) {
|
||||
nir_foreach_instr_reverse(instr, block) {
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_alu:
|
||||
progress |= opt_alu_fp_math_ctrl(nir_instr_as_alu(instr), &state);
|
||||
break;
|
||||
case nir_instr_type_tex:
|
||||
prop_tex_fp_math_ctrl(nir_instr_as_tex(instr));
|
||||
break;
|
||||
case nir_instr_type_intrinsic:
|
||||
prop_intrin_fp_math_ctrl(nir_instr_as_intrinsic(instr));
|
||||
break;
|
||||
case nir_instr_type_phi:
|
||||
if (!instr->pass_flags)
|
||||
break;
|
||||
FALLTHROUGH;
|
||||
default:
|
||||
nir_foreach_src(instr, src_mark_preserve_sz, NULL);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nir_free_fp_analysis_state(&state.fp_class_state);
|
||||
|
||||
return nir_progress(progress, impl, nir_metadata_all);
|
||||
}
|
||||
|
||||
bool
|
||||
nir_opt_fp_math_ctrl(nir_shader *shader)
|
||||
{
|
||||
bool progress = false;
|
||||
|
||||
nir_foreach_function_impl(impl, shader)
|
||||
progress |= opt_fp_math_ctrl_impl(impl);
|
||||
|
||||
return progress;
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue