diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index b2e8100b020..97372e9a6d2 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -1712,6 +1712,61 @@ nir_alu_instr_channel_used(const nir_alu_instr *instr, unsigned src,
 
 bool nir_alu_instr_is_comparison(const nir_alu_instr *instr);
 
+/**
+ * Returns true if \p instr is one of the multiply-add opcodes with regular
+ * zero handling: ffma, ffma_weak, fmad or ffma_old.
+ *
+ * A NULL \p instr is accepted and yields false.
+ */
+static inline bool
+nir_alu_instr_is_mul_add(const nir_alu_instr *instr)
+{
+   if (!instr)
+      return false;
+
+   switch (instr->op) {
+   case nir_op_ffma:
+   case nir_op_ffma_weak:
+   case nir_op_fmad:
+   case nir_op_ffma_old:
+      return true;
+   default:
+      return false;
+   }
+}
+
+/**
+ * Returns true if \p instr is one of the multiply-add opcodes with modified
+ * zero handling: ffmaz, fmadz or ffmaz_old.
+ *
+ * A NULL \p instr is accepted and yields false.
+ */
+static inline bool
+nir_alu_instr_is_mul_add_z(const nir_alu_instr *instr)
+{
+   if (!instr)
+      return false;
+
+   switch (instr->op) {
+   case nir_op_ffmaz:
+   case nir_op_fmadz:
+   case nir_op_ffmaz_old:
+      return true;
+   default:
+      return false;
+   }
+}
+
+/**
+ * Returns true if \p alu is any multiply-add opcode, with or without the
+ * modified zero handling.  A NULL \p alu yields false.
+ */
+static inline bool
+nir_alu_instr_is_any_mul_add(const nir_alu_instr *alu)
+{
+   return nir_alu_instr_is_mul_add(alu) || nir_alu_instr_is_mul_add_z(alu);
+}
+
 bool nir_const_value_negative_equal(nir_const_value c1, nir_const_value c2,
                                     nir_alu_type full_type);
 
@@ -7282,6 +7337,47 @@ nir_is_io_compact(nir_shader *nir, bool is_output, unsigned location)
            (nir->info.stage != MESA_SHADER_MESH && location == VARYING_SLOT_TESS_LEVEL_INNER));
 }
 
+/**
+ * Returns the float multiply-add support flags from the shader's compiler
+ * options for the given bit size.  Only 16, 32 and 64 are valid; any other
+ * \p bit_size hits UNREACHABLE.
+ */
+static inline nir_float_muladd_support
+nir_float_muladd_for_bitsize(const nir_shader *nir, unsigned bit_size)
+{
+   switch (bit_size) {
+   case 16:
+      return nir->options->float_mul_add16;
+   case 32:
+      return nir->options->float_mul_add32;
+   case 64:
+      return nir->options->float_mul_add64;
+   default:
+      UNREACHABLE("unsupported bit_size");
+      return (nir_float_muladd_support)0;
+   }
+}
+
+/** Returns true if the has_ffma flag is set for \p bit_size. */
+static inline bool
+nir_has_ffma(const nir_shader *nir, unsigned bit_size)
+{
+   nir_float_muladd_support muladd = nir_float_muladd_for_bitsize(nir, bit_size);
+   return (muladd & nir_float_muladd_support_has_ffma) != 0;
+}
+
+/**
+ * Returns true if unfused multiply-add should be preferred for \p bit_size:
+ * either the prefers_split flag is set or the has_ffma flag is not.
+ */
+static inline bool
+nir_prefers_fmad(const nir_shader *nir, unsigned bit_size)
+{
+   nir_float_muladd_support muladd = nir_float_muladd_for_bitsize(nir, bit_size);
+   return (muladd & nir_float_muladd_support_prefers_split) != 0 ||
+          (muladd & nir_float_muladd_support_has_ffma) == 0;
+}
+
 #ifdef __cplusplus
 } /* extern "C" */
 #endif
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index f177a77aabd..8422ebe2b56 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -100,6 +100,39 @@ get_float_source(nir_const_value value, unsigned execution_mode, unsigned bit_si
     return nir_const_value_as_float(value, bit_size);
 }
 
+/**
+ * Properly rounds and handles denorms for intermediate results. Useful for
+ * opcodes that combine several operations but have to behave exactly like
+ * the separate, unfused operations, e.g. fmad.
+ *
+ * Only bit sizes 16, 32 and 64 are supported.
+ */
+static double
+handle_intermediate_float_result(double value, unsigned execution_mode, unsigned bit_size)
+{
+    nir_const_value const_val;
+    switch(bit_size) {
+    case 64:
+        const_val.f64 = value;
+        break;
+    case 32:
+        const_val.f32 = value;
+        break;
+    case 16:
+        if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
+            const_val.u16 = _mesa_float_to_float16_rtz(value);
+        } else {
+            const_val.u16 = _mesa_float_to_float16_rtne(value);
+        }
+        break;
+    default:
+        /* const_val would otherwise be read uninitialized below. */
+        UNREACHABLE("unsupported bit_size");
+    }
+
+    return get_float_source(const_val, execution_mode, bit_size);
+}
+
 /**
  * Evaluate one component of packSnorm4x8.
*/ diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index b486e205e55..91f03d0217e 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -1132,6 +1132,66 @@ def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr, [src1_size, src2_size, src3_size], [tuint, tuint, tuint], False, "", const_expr, description) +triop("fmad", tfloat, _2src_commutative, """ +if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { + if (bit_size == 64) + dst = _mesa_double_mul_rtz(src0, src1); + else + dst = _mesa_double_to_float_rtz((double)src0 * (double)src1); +} else { + dst = src0 * src1; +} +dst = handle_intermediate_float_result(dst, execution_mode, bit_size); +if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { + if (bit_size == 64) + dst = _mesa_double_add_rtz(dst, src2); + else + dst = _mesa_double_to_float_rtz((double)dst + (double)src2); +} else { + dst = dst + src2; +} +""", description = "Floating-point unfused multiple-add") + +triop("ffma", tfloat, _2src_commutative, """ +if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { + if (bit_size == 64) + dst = _mesa_double_fma_rtz(src0, src1, src2); + else if (bit_size == 32) + dst = _mesa_float_fma_rtz(src0, src1, src2); + else + dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2)); +} else { + if (bit_size == 32) + dst = fmaf(src0, src1, src2); + else + dst = fma(src0, src1, src2); +} +""", description = "Floating-point fused multiple-add") + +triop("ffma_weak", tfloat, _2src_commutative, """ +if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { + if (bit_size == 64) + dst = _mesa_double_fma_rtz(src0, src1, src2); + else if (bit_size == 32) + dst = _mesa_float_fma_rtz(src0, src1, src2); + else + dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2)); +} else { + if (bit_size == 32) + dst = fmaf(src0, src1, src2); + else + dst = fma(src0, src1, src2); +} +""", description = """ 
+Floating-point multiple-add that can eitehr be unfused or fused. + +Precise ``ffma_weak`` are required to be either fused or unfused across all +shaders and shader stages where inprecise ``ffma_weak`` doesn't have to remain +consistent not even within the same shader. + +This is like GLSLs ``ffma``. +""") + triop("ffma_old", tfloat, _2src_commutative, """ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { if (bit_size == 64) @@ -1148,6 +1208,44 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) { } """) +triop("fmadz", tfloat32, _2src_commutative, """ +if (src0 == 0.0 || src1 == 0.0) { + dst = 0.0 + src2; +} else { + if (nir_is_rounding_mode_rtz(execution_mode, 32)) + dst = _mesa_double_to_float_rtz((double)src0 * (double)src1); + else + dst = (src0 * src1); + + dst = handle_intermediate_float_result(dst, execution_mode, bit_size); + + if (nir_is_rounding_mode_rtz(execution_mode, 32)) + dst = _mesa_double_to_float_rtz((double)dst + (double)src2); + dst = dst + src2; +} +""", description = """ +Floating-point unfused multiply-add with modified zero handling. + +Unlike :nir:alu-op:`fmad`, anything (even infinity or NaN) multiplied by +/-0.0 is ++0.0. ``fmadz(0.0, inf, src2)`` and ``fmadz(0.0, nan, src2)`` must be +``+0.0 + src2``. +""") + +triop("ffmaz", tfloat32, _2src_commutative, """ +if (src0 == 0.0 || src1 == 0.0) + dst = 0.0 + src2; +else if (nir_is_rounding_mode_rtz(execution_mode, 32)) + dst = _mesa_float_fma_rtz(src0, src1, src2); +else + dst = fmaf(src0, src1, src2); +""", description = """ +Floating-point fused multiply-add with modified zero handling. + +Unlike :nir:alu-op:`ffma`, anything (even infinity or NaN) multiplied by +/-0.0 is ++0.0. ``ffmaz(0.0, inf, src2)`` and ``ffmaz(0.0, nan, src2)`` must be +``+0.0 + src2``. 
+""") + triop("ffmaz_old", tfloat32, _2src_commutative, """ if (src0 == 0.0 || src1 == 0.0) dst = 0.0 + src2; diff --git a/src/compiler/nir/nir_shader_compiler_options.h b/src/compiler/nir/nir_shader_compiler_options.h index c1e03b87f4f..f03b6be99f8 100644 --- a/src/compiler/nir/nir_shader_compiler_options.h +++ b/src/compiler/nir/nir_shader_compiler_options.h @@ -259,6 +259,20 @@ typedef enum { nir_frag_coord_use_pixel_coord = BITFIELD_BIT(2), } nir_frag_coord_form; +typedef enum { + nir_float_muladd_support_has_ffma = 0x01, + nir_float_muladd_support_has_fmad = 0x02, + + /** Strongly hints that fmad or fmul+fadd is preferred over ffma */ + nir_float_muladd_support_prefers_split = 0x04, + + /** ffma_weak won't be lowered */ + nir_float_muladd_support_keep_weak_ffma = 0x08, + + nir_float_muladd_support_fuse = 0x10, +} nir_float_muladd_support; +MESA_DEFINE_CPP_ENUM_BITFIELD_OPERATORS(nir_float_muladd_support) + typedef struct nir_shader_compiler_options { bool lower_fdiv; bool lower_ffma16; @@ -267,6 +281,9 @@ typedef struct nir_shader_compiler_options { bool fuse_ffma16; bool fuse_ffma32; bool fuse_ffma64; + nir_float_muladd_support float_mul_add16; + nir_float_muladd_support float_mul_add32; + nir_float_muladd_support float_mul_add64; bool lower_flrp16; bool lower_flrp32; /** Lowers flrp when it does not support doubles */