mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-06-04 15:18:15 +02:00
nir: add new float multiply-add opcodes
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41165>
This commit is contained in:
parent
0251be32cb
commit
32e91a7467
4 changed files with 212 additions and 0 deletions
|
|
@ -1712,6 +1712,45 @@ nir_alu_instr_channel_used(const nir_alu_instr *instr, unsigned src,
|
||||||
bool
|
bool
|
||||||
nir_alu_instr_is_comparison(const nir_alu_instr *instr);
|
nir_alu_instr_is_comparison(const nir_alu_instr *instr);
|
||||||
|
|
||||||
|
static inline bool
|
||||||
|
nir_alu_instr_is_mul_add(const nir_alu_instr *instr)
|
||||||
|
{
|
||||||
|
if (!instr)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
switch (instr->op) {
|
||||||
|
case nir_op_ffma:
|
||||||
|
case nir_op_ffma_weak:
|
||||||
|
case nir_op_fmad:
|
||||||
|
case nir_op_ffma_old:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool
|
||||||
|
nir_alu_instr_is_mul_add_z(const nir_alu_instr *instr)
|
||||||
|
{
|
||||||
|
if (!instr)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
switch (instr->op) {
|
||||||
|
case nir_op_ffmaz:
|
||||||
|
case nir_op_fmadz:
|
||||||
|
case nir_op_ffmaz_old:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool
|
||||||
|
nir_alu_instr_is_any_mul_add(const nir_alu_instr *alu)
|
||||||
|
{
|
||||||
|
return nir_alu_instr_is_mul_add(alu) || nir_alu_instr_is_mul_add_z(alu);
|
||||||
|
}
|
||||||
|
|
||||||
bool nir_const_value_negative_equal(nir_const_value c1, nir_const_value c2,
|
bool nir_const_value_negative_equal(nir_const_value c1, nir_const_value c2,
|
||||||
nir_alu_type full_type);
|
nir_alu_type full_type);
|
||||||
|
|
||||||
|
|
@ -7282,6 +7321,37 @@ nir_is_io_compact(nir_shader *nir, bool is_output, unsigned location)
|
||||||
(nir->info.stage != MESA_SHADER_MESH && location == VARYING_SLOT_TESS_LEVEL_INNER));
|
(nir->info.stage != MESA_SHADER_MESH && location == VARYING_SLOT_TESS_LEVEL_INNER));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline nir_float_muladd_support
|
||||||
|
nir_float_muladd_for_bitsize(const nir_shader *nir, unsigned bit_size)
|
||||||
|
{
|
||||||
|
switch (bit_size) {
|
||||||
|
case 16:
|
||||||
|
return nir->options->float_mul_add16;
|
||||||
|
case 32:
|
||||||
|
return nir->options->float_mul_add32;
|
||||||
|
case 64:
|
||||||
|
return nir->options->float_mul_add64;
|
||||||
|
default:
|
||||||
|
UNREACHABLE("unsupported bit_size");
|
||||||
|
return (nir_float_muladd_support)0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool
|
||||||
|
nir_has_ffma(const nir_shader *nir, unsigned bit_size)
|
||||||
|
{
|
||||||
|
nir_float_muladd_support muladd = nir_float_muladd_for_bitsize(nir, bit_size);
|
||||||
|
return (muladd & nir_float_muladd_support_has_ffma) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool
|
||||||
|
nir_prefers_fmad(const nir_shader *nir, unsigned bit_size)
|
||||||
|
{
|
||||||
|
nir_float_muladd_support muladd = nir_float_muladd_for_bitsize(nir, bit_size);
|
||||||
|
return (muladd & nir_float_muladd_support_prefers_split) != 0 ||
|
||||||
|
(muladd & nir_float_muladd_support_has_ffma) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* extern "C" */
|
} /* extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -100,6 +100,33 @@ get_float_source(nir_const_value value, unsigned execution_mode, unsigned bit_si
|
||||||
return nir_const_value_as_float(value, bit_size);
|
return nir_const_value_as_float(value, bit_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Properly rounds and handles denorms for intermediate results. Useful for
|
||||||
|
* fused opcodes that want to behave exactly like the unfused variants, e.g.
|
||||||
|
* fmad.
|
||||||
|
*/
|
||||||
|
static double
|
||||||
|
handle_intermediate_float_result(double value, unsigned execution_mode, unsigned bit_size)
|
||||||
|
{
|
||||||
|
nir_const_value const_val;
|
||||||
|
switch(bit_size) {
|
||||||
|
case 64:
|
||||||
|
const_val.f64 = value;
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
const_val.f32 = value;
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
|
||||||
|
const_val.u16 = _mesa_float_to_float16_rtz(value);
|
||||||
|
} else {
|
||||||
|
const_val.u16 = _mesa_float_to_float16_rtne(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return get_float_source(const_val, execution_mode, bit_size);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Evaluate one component of packSnorm4x8.
|
* Evaluate one component of packSnorm4x8.
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
|
|
@ -1132,6 +1132,66 @@ def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr,
|
||||||
[src1_size, src2_size, src3_size],
|
[src1_size, src2_size, src3_size],
|
||||||
[tuint, tuint, tuint], False, "", const_expr, description)
|
[tuint, tuint, tuint], False, "", const_expr, description)
|
||||||
|
|
||||||
|
triop("fmad", tfloat, _2src_commutative, """
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||||
|
if (bit_size == 64)
|
||||||
|
dst = _mesa_double_mul_rtz(src0, src1);
|
||||||
|
else
|
||||||
|
dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
|
||||||
|
} else {
|
||||||
|
dst = src0 * src1;
|
||||||
|
}
|
||||||
|
dst = handle_intermediate_float_result(dst, execution_mode, bit_size);
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||||
|
if (bit_size == 64)
|
||||||
|
dst = _mesa_double_add_rtz(dst, src2);
|
||||||
|
else
|
||||||
|
dst = _mesa_double_to_float_rtz((double)dst + (double)src2);
|
||||||
|
} else {
|
||||||
|
dst = dst + src2;
|
||||||
|
}
|
||||||
|
""", description = "Floating-point unfused multiple-add")
|
||||||
|
|
||||||
|
triop("ffma", tfloat, _2src_commutative, """
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||||
|
if (bit_size == 64)
|
||||||
|
dst = _mesa_double_fma_rtz(src0, src1, src2);
|
||||||
|
else if (bit_size == 32)
|
||||||
|
dst = _mesa_float_fma_rtz(src0, src1, src2);
|
||||||
|
else
|
||||||
|
dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
|
||||||
|
} else {
|
||||||
|
if (bit_size == 32)
|
||||||
|
dst = fmaf(src0, src1, src2);
|
||||||
|
else
|
||||||
|
dst = fma(src0, src1, src2);
|
||||||
|
}
|
||||||
|
""", description = "Floating-point fused multiple-add")
|
||||||
|
|
||||||
|
triop("ffma_weak", tfloat, _2src_commutative, """
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||||
|
if (bit_size == 64)
|
||||||
|
dst = _mesa_double_fma_rtz(src0, src1, src2);
|
||||||
|
else if (bit_size == 32)
|
||||||
|
dst = _mesa_float_fma_rtz(src0, src1, src2);
|
||||||
|
else
|
||||||
|
dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
|
||||||
|
} else {
|
||||||
|
if (bit_size == 32)
|
||||||
|
dst = fmaf(src0, src1, src2);
|
||||||
|
else
|
||||||
|
dst = fma(src0, src1, src2);
|
||||||
|
}
|
||||||
|
""", description = """
|
||||||
|
Floating-point multiple-add that can eitehr be unfused or fused.
|
||||||
|
|
||||||
|
Precise ``ffma_weak`` are required to be either fused or unfused across all
|
||||||
|
shaders and shader stages where inprecise ``ffma_weak`` doesn't have to remain
|
||||||
|
consistent not even within the same shader.
|
||||||
|
|
||||||
|
This is like GLSLs ``ffma``.
|
||||||
|
""")
|
||||||
|
|
||||||
triop("ffma_old", tfloat, _2src_commutative, """
|
triop("ffma_old", tfloat, _2src_commutative, """
|
||||||
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||||
if (bit_size == 64)
|
if (bit_size == 64)
|
||||||
|
|
@ -1148,6 +1208,44 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
|
||||||
}
|
}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
triop("fmadz", tfloat32, _2src_commutative, """
|
||||||
|
if (src0 == 0.0 || src1 == 0.0) {
|
||||||
|
dst = 0.0 + src2;
|
||||||
|
} else {
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, 32))
|
||||||
|
dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
|
||||||
|
else
|
||||||
|
dst = (src0 * src1);
|
||||||
|
|
||||||
|
dst = handle_intermediate_float_result(dst, execution_mode, bit_size);
|
||||||
|
|
||||||
|
if (nir_is_rounding_mode_rtz(execution_mode, 32))
|
||||||
|
dst = _mesa_double_to_float_rtz((double)dst + (double)src2);
|
||||||
|
dst = dst + src2;
|
||||||
|
}
|
||||||
|
""", description = """
|
||||||
|
Floating-point unfused multiply-add with modified zero handling.
|
||||||
|
|
||||||
|
Unlike :nir:alu-op:`fmad`, anything (even infinity or NaN) multiplied by +/-0.0 is
|
||||||
|
+0.0. ``fmadz(0.0, inf, src2)`` and ``fmadz(0.0, nan, src2)`` must be
|
||||||
|
``+0.0 + src2``.
|
||||||
|
""")
|
||||||
|
|
||||||
|
triop("ffmaz", tfloat32, _2src_commutative, """
|
||||||
|
if (src0 == 0.0 || src1 == 0.0)
|
||||||
|
dst = 0.0 + src2;
|
||||||
|
else if (nir_is_rounding_mode_rtz(execution_mode, 32))
|
||||||
|
dst = _mesa_float_fma_rtz(src0, src1, src2);
|
||||||
|
else
|
||||||
|
dst = fmaf(src0, src1, src2);
|
||||||
|
""", description = """
|
||||||
|
Floating-point fused multiply-add with modified zero handling.
|
||||||
|
|
||||||
|
Unlike :nir:alu-op:`ffma`, anything (even infinity or NaN) multiplied by +/-0.0 is
|
||||||
|
+0.0. ``ffmaz(0.0, inf, src2)`` and ``ffmaz(0.0, nan, src2)`` must be
|
||||||
|
``+0.0 + src2``.
|
||||||
|
""")
|
||||||
|
|
||||||
triop("ffmaz_old", tfloat32, _2src_commutative, """
|
triop("ffmaz_old", tfloat32, _2src_commutative, """
|
||||||
if (src0 == 0.0 || src1 == 0.0)
|
if (src0 == 0.0 || src1 == 0.0)
|
||||||
dst = 0.0 + src2;
|
dst = 0.0 + src2;
|
||||||
|
|
|
||||||
|
|
@ -259,6 +259,20 @@ typedef enum {
|
||||||
nir_frag_coord_use_pixel_coord = BITFIELD_BIT(2),
|
nir_frag_coord_use_pixel_coord = BITFIELD_BIT(2),
|
||||||
} nir_frag_coord_form;
|
} nir_frag_coord_form;
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
nir_float_muladd_support_has_ffma = 0x01,
|
||||||
|
nir_float_muladd_support_has_fmad = 0x02,
|
||||||
|
|
||||||
|
/** Strongly hints that fmad or fmul+fadd is preferred over ffma */
|
||||||
|
nir_float_muladd_support_prefers_split = 0x04,
|
||||||
|
|
||||||
|
/** ffma_weak won't be lowered */
|
||||||
|
nir_float_muladd_support_keep_weak_ffma = 0x08,
|
||||||
|
|
||||||
|
nir_float_muladd_support_fuse = 0x10,
|
||||||
|
} nir_float_muladd_support;
|
||||||
|
MESA_DEFINE_CPP_ENUM_BITFIELD_OPERATORS(nir_float_muladd_support)
|
||||||
|
|
||||||
typedef struct nir_shader_compiler_options {
|
typedef struct nir_shader_compiler_options {
|
||||||
bool lower_fdiv;
|
bool lower_fdiv;
|
||||||
bool lower_ffma16;
|
bool lower_ffma16;
|
||||||
|
|
@ -267,6 +281,9 @@ typedef struct nir_shader_compiler_options {
|
||||||
bool fuse_ffma16;
|
bool fuse_ffma16;
|
||||||
bool fuse_ffma32;
|
bool fuse_ffma32;
|
||||||
bool fuse_ffma64;
|
bool fuse_ffma64;
|
||||||
|
nir_float_muladd_support float_mul_add16;
|
||||||
|
nir_float_muladd_support float_mul_add32;
|
||||||
|
nir_float_muladd_support float_mul_add64;
|
||||||
bool lower_flrp16;
|
bool lower_flrp16;
|
||||||
bool lower_flrp32;
|
bool lower_flrp32;
|
||||||
/** Lowers flrp when it does not support doubles */
|
/** Lowers flrp when it does not support doubles */
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue