vtn: fuse OpenCL mad if we can can

clpeak "float" case from 1112 -> 1978 GFLOPS on rusticl on m1.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26932>
This commit is contained in:
Alyssa Rosenzweig 2024-01-08 12:01:49 -04:00 committed by Marge Bot
parent f6d2df5a75
commit 3da2773316

View file

@ -508,8 +508,25 @@ handle_special(struct vtn_builder *b, uint32_t opcode,
return nir_cross3(nb, srcs[0], srcs[1]);
case OpenCLstd_Fdim:
return nir_fdim(nb, srcs[0], srcs[1]);
case OpenCLstd_Mad:
return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
case OpenCLstd_Mad: {
/* The spec says mad is
*
* Implemented either as a correctly rounded fma or as a multiply
* followed by an add both of which are correctly rounded
*
* So lower to fmul+fadd if we have to, but fuse to an ffma if the backend
* supports that. This can be significantly faster.
*/
bool lower =
((nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16) ||
(nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) ||
(nb->shader->options->lower_ffma64 && srcs[0]->bit_size == 64));
if (lower)
return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
else
return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
}
case OpenCLstd_Maxmag:
return nir_maxmag(nb, srcs[0], srcs[1]);
case OpenCLstd_Minmag: