mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-05 00:58:05 +02:00
Revert "spirv: Use a simpler and more correct implementaiton of tanh()"
This reverts commit da1c49171d. The reduced formula has precision problems on fp16 around 0. Bring back the old formula, but make sure to keep the clamping. Tested-by: Marge Bot <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4054> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4054> (cherry picked from commit 9f9432d56c)
This commit is contained in:
parent
a9ee554df3
commit
45e14cbdc7
2 changed files with 14 additions and 15 deletions
|
|
@ -31,7 +31,7 @@
|
|||
"description": "Revert \"spirv: Use a simpler and more correct implementaiton of tanh()\"",
|
||||
"nominated": true,
|
||||
"nomination_type": 2,
|
||||
"resolution": 0,
|
||||
"resolution": 1,
|
||||
"master_sha": null,
|
||||
"because_sha": "da1c49171d0df185545cfbbd600e287f7c6160fa"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -458,25 +458,24 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
|
|||
return;
|
||||
|
||||
case GLSLstd450Tanh: {
|
||||
/* tanh(x) := (0.5 * (e^x - e^(-x))) / (0.5 * (e^x + e^(-x)))
|
||||
/* tanh(x) := (e^x - e^(-x)) / (e^x + e^(-x))
|
||||
*
|
||||
* With a little algebra this reduces to (e^2x - 1) / (e^2x + 1)
|
||||
* We clamp x to [-10, +10] to avoid precision problems. When x > 10,
|
||||
* e^x dominates the sum, e^(-x) is lost and tanh(x) is 1.0 for 32 bit
|
||||
* floating point.
|
||||
*
|
||||
* We clamp x to (-inf, +10] to avoid precision problems. When x > 10,
|
||||
* e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the
|
||||
* computation e^2x +/- 1 so it can be ignored.
|
||||
*
|
||||
* For 16-bit precision we clamp x to (-inf, +4.2] since the maximum
|
||||
* representable number is only 65,504 and e^(2*6) exceeds that. Also,
|
||||
* if x > 4.2, tanh(x) will return 1.0 in fp16.
|
||||
* For 16-bit precision we clamp x to [-4.2, +4.2].
|
||||
*/
|
||||
const uint32_t bit_size = src[0]->bit_size;
|
||||
const double clamped_x = bit_size > 16 ? 10.0 : 4.2;
|
||||
nir_ssa_def *x = nir_fmin(nb, src[0],
|
||||
nir_imm_floatN_t(nb, clamped_x, bit_size));
|
||||
nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0));
|
||||
val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0),
|
||||
nir_fadd_imm(nb, exp2x, 1.0));
|
||||
nir_ssa_def *x = nir_fclamp(nb, src[0],
|
||||
nir_imm_floatN_t(nb, -clamped_x, bit_size),
|
||||
nir_imm_floatN_t(nb, clamped_x, bit_size));
|
||||
val->ssa->def =
|
||||
nir_fdiv(nb, nir_fsub(nb, build_exp(nb, x),
|
||||
build_exp(nb, nir_fneg(nb, x))),
|
||||
nir_fadd(nb, build_exp(nb, x),
|
||||
build_exp(nb, nir_fneg(nb, x))));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue