mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-25 10:28:11 +02:00
nir: add new float multiply-add opcodes
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41165>
This commit is contained in:
parent
0251be32cb
commit
32e91a7467
4 changed files with 212 additions and 0 deletions
|
|
@ -1712,6 +1712,45 @@ nir_alu_instr_channel_used(const nir_alu_instr *instr, unsigned src,
|
|||
bool
|
||||
nir_alu_instr_is_comparison(const nir_alu_instr *instr);
|
||||
|
||||
static inline bool
|
||||
nir_alu_instr_is_mul_add(const nir_alu_instr *instr)
|
||||
{
|
||||
if (!instr)
|
||||
return false;
|
||||
|
||||
switch (instr->op) {
|
||||
case nir_op_ffma:
|
||||
case nir_op_ffma_weak:
|
||||
case nir_op_fmad:
|
||||
case nir_op_ffma_old:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool
|
||||
nir_alu_instr_is_mul_add_z(const nir_alu_instr *instr)
|
||||
{
|
||||
if (!instr)
|
||||
return false;
|
||||
|
||||
switch (instr->op) {
|
||||
case nir_op_ffmaz:
|
||||
case nir_op_fmadz:
|
||||
case nir_op_ffmaz_old:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool
|
||||
nir_alu_instr_is_any_mul_add(const nir_alu_instr *alu)
|
||||
{
|
||||
return nir_alu_instr_is_mul_add(alu) || nir_alu_instr_is_mul_add_z(alu);
|
||||
}
|
||||
|
||||
bool nir_const_value_negative_equal(nir_const_value c1, nir_const_value c2,
|
||||
nir_alu_type full_type);
|
||||
|
||||
|
|
@ -7282,6 +7321,37 @@ nir_is_io_compact(nir_shader *nir, bool is_output, unsigned location)
|
|||
(nir->info.stage != MESA_SHADER_MESH && location == VARYING_SLOT_TESS_LEVEL_INNER));
|
||||
}
|
||||
|
||||
static inline nir_float_muladd_support
|
||||
nir_float_muladd_for_bitsize(const nir_shader *nir, unsigned bit_size)
|
||||
{
|
||||
switch (bit_size) {
|
||||
case 16:
|
||||
return nir->options->float_mul_add16;
|
||||
case 32:
|
||||
return nir->options->float_mul_add32;
|
||||
case 64:
|
||||
return nir->options->float_mul_add64;
|
||||
default:
|
||||
UNREACHABLE("unsupported bit_size");
|
||||
return (nir_float_muladd_support)0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool
|
||||
nir_has_ffma(const nir_shader *nir, unsigned bit_size)
|
||||
{
|
||||
nir_float_muladd_support muladd = nir_float_muladd_for_bitsize(nir, bit_size);
|
||||
return (muladd & nir_float_muladd_support_has_ffma) != 0;
|
||||
}
|
||||
|
||||
static inline bool
|
||||
nir_prefers_fmad(const nir_shader *nir, unsigned bit_size)
|
||||
{
|
||||
nir_float_muladd_support muladd = nir_float_muladd_for_bitsize(nir, bit_size);
|
||||
return (muladd & nir_float_muladd_support_prefers_split) != 0 ||
|
||||
(muladd & nir_float_muladd_support_has_ffma) == 0;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -100,6 +100,33 @@ get_float_source(nir_const_value value, unsigned execution_mode, unsigned bit_si
|
|||
return nir_const_value_as_float(value, bit_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* Properly rounds and handles denorms for intermediate results. Useful for
|
||||
* fused opcodes that want to behave exactly like the unfused variants, e.g.
|
||||
* fmad.
|
||||
*/
|
||||
static double
|
||||
handle_intermediate_float_result(double value, unsigned execution_mode, unsigned bit_size)
|
||||
{
|
||||
nir_const_value const_val;
|
||||
switch(bit_size) {
|
||||
case 64:
|
||||
const_val.f64 = value;
|
||||
break;
|
||||
case 32:
|
||||
const_val.f32 = value;
|
||||
break;
|
||||
case 16:
|
||||
if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
|
||||
const_val.u16 = _mesa_float_to_float16_rtz(value);
|
||||
} else {
|
||||
const_val.u16 = _mesa_float_to_float16_rtne(value);
|
||||
}
|
||||
}
|
||||
|
||||
return get_float_source(const_val, execution_mode, bit_size);
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate one component of packSnorm4x8.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -1132,6 +1132,66 @@ def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr,
|
|||
[src1_size, src2_size, src3_size],
|
||||
[tuint, tuint, tuint], False, "", const_expr, description)
|
||||
|
||||
# fmad: unfused multiply-add. The product src0 * src1 is computed first (with
# explicit round-toward-zero handling when that rounding mode is active),
# passed through handle_intermediate_float_result(), and only then is src2
# added, again honouring the rounding mode.
triop("fmad", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_mul_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
} else {
   dst = src0 * src1;
}
dst = handle_intermediate_float_result(dst, execution_mode, bit_size);
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_add_rtz(dst, src2);
   else
      dst = _mesa_double_to_float_rtz((double)dst + (double)src2);
} else {
   dst = dst + src2;
}
""", description = "Floating-point unfused multiply-add")
|
||||
|
||||
# ffma: fused multiply-add. Uses the _mesa_*_fma_rtz helpers when the rounding
# mode is round-toward-zero, and the C library fmaf()/fma() otherwise.
triop("ffma", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_fma_rtz(src0, src1, src2);
   else if (bit_size == 32)
      dst = _mesa_float_fma_rtz(src0, src1, src2);
   else
      dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
} else {
   if (bit_size == 32)
      dst = fmaf(src0, src1, src2);
   else
      dst = fma(src0, src1, src2);
}
""", description = "Floating-point fused multiply-add")
|
||||
|
||||
# ffma_weak: multiply-add that may be implemented fused or unfused. The
# constant-folding expression below evaluates it as a fused multiply-add.
triop("ffma_weak", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_fma_rtz(src0, src1, src2);
   else if (bit_size == 32)
      dst = _mesa_float_fma_rtz(src0, src1, src2);
   else
      dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
} else {
   if (bit_size == 32)
      dst = fmaf(src0, src1, src2);
   else
      dst = fma(src0, src1, src2);
}
""", description = """
Floating-point multiply-add that can either be unfused or fused.

Precise ``ffma_weak`` instructions are required to be consistently either
fused or unfused across all shaders and shader stages, whereas imprecise
``ffma_weak`` does not have to remain consistent, not even within the same
shader.

This is like GLSL's ``fma``.
""")
|
||||
|
||||
triop("ffma_old", tfloat, _2src_commutative, """
|
||||
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||
if (bit_size == 64)
|
||||
|
|
@ -1148,6 +1208,44 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
|||
}
|
||||
""")
|
||||
|
||||
# fmadz: unfused multiply-add where a zero multiplicand forces the product to
# +0.0. Otherwise the product is rounded on its own (through
# handle_intermediate_float_result()) before src2 is added. The RTZ and the
# default-rounding additions are mutually exclusive, so src2 is added exactly
# once in either rounding mode.
triop("fmadz", tfloat32, _2src_commutative, """
if (src0 == 0.0 || src1 == 0.0) {
   dst = 0.0 + src2;
} else {
   if (nir_is_rounding_mode_rtz(execution_mode, 32))
      dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
   else
      dst = (src0 * src1);

   dst = handle_intermediate_float_result(dst, execution_mode, bit_size);

   if (nir_is_rounding_mode_rtz(execution_mode, 32))
      dst = _mesa_double_to_float_rtz((double)dst + (double)src2);
   else
      dst = dst + src2;
}
""", description = """
Floating-point unfused multiply-add with modified zero handling.

Unlike :nir:alu-op:`fmad`, anything (even infinity or NaN) multiplied by +/-0.0 is
+0.0. ``fmadz(0.0, inf, src2)`` and ``fmadz(0.0, nan, src2)`` must be
``+0.0 + src2``.
""")
|
||||
|
||||
# ffmaz: fused multiply-add with modified zero handling, 32-bit float only.
# If either multiplicand compares equal to 0.0 the result is 0.0 + src2;
# otherwise it is a fused multiply-add, using the RTZ helper when the 32-bit
# rounding mode is round-toward-zero and fmaf() otherwise.
triop("ffmaz", tfloat32, _2src_commutative, """
|
||||
if (src0 == 0.0 || src1 == 0.0)
|
||||
dst = 0.0 + src2;
|
||||
else if (nir_is_rounding_mode_rtz(execution_mode, 32))
|
||||
dst = _mesa_float_fma_rtz(src0, src1, src2);
|
||||
else
|
||||
dst = fmaf(src0, src1, src2);
|
||||
""", description = """
|
||||
Floating-point fused multiply-add with modified zero handling.
|
||||
|
||||
Unlike :nir:alu-op:`ffma`, anything (even infinity or NaN) multiplied by +/-0.0 is
|
||||
+0.0. ``ffmaz(0.0, inf, src2)`` and ``ffmaz(0.0, nan, src2)`` must be
|
||||
``+0.0 + src2``.
|
||||
""")
|
||||
|
||||
triop("ffmaz_old", tfloat32, _2src_commutative, """
|
||||
if (src0 == 0.0 || src1 == 0.0)
|
||||
dst = 0.0 + src2;
|
||||
|
|
|
|||
|
|
@ -259,6 +259,20 @@ typedef enum {
|
|||
nir_frag_coord_use_pixel_coord = BITFIELD_BIT(2),
|
||||
} nir_frag_coord_form;
|
||||
|
||||
/**
 * Bit flags describing float multiply-add support. One value is stored per
 * bit size in nir_shader_compiler_options (float_mul_add16/32/64).
 */
typedef enum {
   /* NOTE(review): presumably "the fused ffma opcode is available" - confirm */
   nir_float_muladd_support_has_ffma = 1 << 0,
   /* NOTE(review): presumably "the unfused fmad opcode is available" - confirm */
   nir_float_muladd_support_has_fmad = 1 << 1,

   /** Strong hint that fmad or a separate fmul+fadd is preferred to ffma */
   nir_float_muladd_support_prefers_split = 1 << 2,

   /** When set, ffma_weak is not lowered */
   nir_float_muladd_support_keep_weak_ffma = 1 << 3,

   /* NOTE(review): meaning not visible here; presumably enables fusing - confirm */
   nir_float_muladd_support_fuse = 1 << 4,
} nir_float_muladd_support;
|
||||
MESA_DEFINE_CPP_ENUM_BITFIELD_OPERATORS(nir_float_muladd_support)
|
||||
|
||||
typedef struct nir_shader_compiler_options {
|
||||
bool lower_fdiv;
|
||||
bool lower_ffma16;
|
||||
|
|
@ -267,6 +281,9 @@ typedef struct nir_shader_compiler_options {
|
|||
bool fuse_ffma16;
|
||||
bool fuse_ffma32;
|
||||
bool fuse_ffma64;
|
||||
nir_float_muladd_support float_mul_add16;
|
||||
nir_float_muladd_support float_mul_add32;
|
||||
nir_float_muladd_support float_mul_add64;
|
||||
bool lower_flrp16;
|
||||
bool lower_flrp32;
|
||||
/** Lowers flrp when it does not support doubles */
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue