Best non-trigonometric floating point approximation of tanh(x) in 10 instructions or less

Description

I need a reasonably accurate, fast hyperbolic tangent for a machine that has no built-in floating-point trigonometry, so, e.g., the usual tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) formula would need an approximation of exp(2x).
All other instructions like addition, subtraction, multiplication, division, and even FMA (= MUL+ADD in 1 op) are present.

Right now I have several approximations, but none of them are satisfactory in terms of accuracy.

[Update from the comments:]

  • The instruction for trunc() / floor() is available
  • There is a way to transparently reinterpret floats as integers and do all kinds of bit ops
  • There is a family of instructions called SEL.xx (.GT, .LE, etc.) which compare 2 values and choose what to write to the destination
  • DIVs are twice as slow, which is nothing exceptional, so DIVs are okay to use

Approach 1

Accuracy: ±1.2% absolute error.

Pseudocode (A = accumulator register, T = temporary register):

[1] FMA T, 36.f / 73.f, A, A   // T := 36/73 + X^2
[2] MUL A, A, T                // A := X(36/73 + X^2)
[3] ABS T, A                   // T := |X(36/73 + X^2)|
[4] ADD T, T, 32.f / 73.f      // T := |X(36/73 + X^2)| + 32/73
[5] DIV A, A, T                // A := X(36/73 + X^2) / (|X(36/73 + X^2)| + 32/73)

Approach 2

Accuracy: ±0.9% absolute error.

Pseudocode (A = accumulator register, T = temporary register):

[1] FMA T, 3.125f, A, A        // T := 3.125 + X^2
[2] DIV T, 25.125f, T          // T := 25.125/(3.125 + X^2)
[3] MUL A, A, 0.1073f          // A := 0.1073*X
[4] FMA A, A, A, T             // A := 0.1073*X + 0.1073*X*25.125/(3.125 + X^2)
[5] MIN A, A, 1.f              // A := min(0.1073*X + 0.1073*X*25.125/(3.125 + X^2), 1)
[6] MAX A, A, -1.f             // A := max(min(0.1073*X + 0.1073*X*25.125/(3.125 + X^2), 1), -1)

Approach 3

Accuracy: ±0.13% absolute error.

Pseudocode (A = accumulator register, T = temporary register):

[1] FMA T, 14.f, A, A          // T := 14 + X^2
[2] FMA T, -133.f, T, T        // T := (14 + X^2)^2 - 133
[3] DIV T, A, T                // T := X/((14 + X^2)^2 - 133)
[4] FMA A, 52.5f, A, A         // A := 52.5 + X^2
[5] MUL A, A, RSQRT(15.f)      // A := (52.5 + X^2)/sqrt(15)
[6] FMA A, -120.75f, A, A      // A := (52.5 + X^2)^2/15 - 120.75
[7] MUL A, A, T                // A := ((52.5 + X^2)^2/15 - 120.75)*X/((14 + X^2)^2 - 133)
[8] MIN A, A, 1.f              // A := min(((52.5 + X^2)^2/15 - 120.75)*X/((14 + X^2)^2 - 133), 1)
[9] MAX A, A, -1.f             // A := max(min(((52.5 + X^2)^2/15 - 120.75)*X/((14 + X^2)^2 - 133), 1), -1)
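
For readers who want to try the three listings above on an ordinary machine, here is a rough C transcription. It is only a sketch: the function names and the clamp_unit helper are made up for illustration, 1/sqrt(15) is written as a precomputed constant, and each a * b + c is spelled out as a plain expression that a compiler may or may not contract into a single FMA, so the low-order bits can differ from what the target hardware produces.

```c
/* Hypothetical C transcriptions of the three pseudocode listings.
   The bracketed numbers refer to the pseudocode steps. */

/* MIN(.., 1) followed by MAX(.., -1) */
static float clamp_unit(float v)
{
    return v > 1.0f ? 1.0f : (v < -1.0f ? -1.0f : v);
}

float tanh_approach1(float x)
{
    float t = x * x + 36.0f / 73.0f;      /* [1] T := 36/73 + X^2 */
    float a = x * t;                      /* [2] A := X * T */
    t = a < 0.0f ? -a : a;                /* [3] T := |A| */
    t = t + 32.0f / 73.0f;                /* [4] T := |A| + 32/73 */
    return a / t;                         /* [5] */
}

float tanh_approach2(float x)
{
    float t = x * x + 3.125f;             /* [1] */
    t = 25.125f / t;                      /* [2] */
    float a = x * 0.1073f;                /* [3] */
    a = a * t + a;                        /* [4] */
    return clamp_unit(a);                 /* [5], [6] */
}

float tanh_approach3(float x)
{
    const float rsqrt15 = 0.258198890f;   /* 1/sqrt(15), precomputed */
    float t = x * x + 14.0f;              /* [1] */
    t = t * t - 133.0f;                   /* [2] */
    t = x / t;                            /* [3] */
    float a = x * x + 52.5f;              /* [4] */
    a = a * rsqrt15;                      /* [5] */
    a = a * a - 120.75f;                  /* [6] */
    a = a * t;                            /* [7] */
    return clamp_unit(a);                 /* [8], [9] */
}
```

Comparing each function against a double-precision tanh() over a range of inputs is one way to check the accuracy figures quoted for each approach.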

The question

Is there anything better that can possibly fit in 10 non-trigonometric float32 instructions?

After doing much exploratory work, I came to the conclusion that approach 2 is the most promising direction. Since division is very fast on the asker's platform, rational approximations are attractive. The platform's support for FMA should be exploited aggressively. Below I am showing C code that implements a fast tanhf() in seven operations and achieves a maximum absolute error of less than 3.3e-3.

I used the Remez algorithm to compute the coefficients for the rational approximation and used a heuristic search to reduce these coefficients to as few bits as feasible, which may benefit some processor architectures that are able to incorporate floating-point data into an immediate field of commonly used floating-point instructions.
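
The coefficient search itself is not shown in this answer. As a small illustration of what "few bits" means here, the following sketch counts how many fraction (mantissa) bits of a float32 value are actually needed once trailing zero bits are dropped. The function name significant_fraction_bits is hypothetical, and the sketch assumes a finite, normal IEEE-754 binary32 input.

```c
#include <stdint.h>
#include <string.h>

/* Number of fraction bits of a finite, normal float32 that are actually
   used: 23 minus the count of trailing zero bits in the fraction field.
   Hypothetical helper for illustration only. */
int significant_fraction_bits(float f)
{
    uint32_t u;
    memcpy(&u, &f, sizeof u);            /* reinterpret the float's bit pattern */
    uint32_t frac = u & 0x007FFFFFu;     /* low 23 bits hold the fraction */
    int bits = 23;
    while (bits > 0 && (frac & 1u) == 0) {
        frac >>= 1;                      /* drop one trailing zero bit */
        bits--;
    }
    return bits;
}
```

Applied to the constants in the code that follows, this reports 11 bits for n0 = -8.69873047e-1f (-0x1.bd6000p-1) and 8 bits for d0 = 2.72656250e+0f (0x1.5d0000p+1), which matches the trailing zeros visible in the hexadecimal comments.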

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/* Fast computation of hyperbolic tangent. Rational approximation with clamping.
   Maximum absolute error = 3.20235857e-3 @ +/-3.21770620
*/
float fast_tanhf_rat (float x)
{
    const float n0 = -8.69873047e-1f; // -0x1.bd6000p-1
    const float n1 = -8.78143311e-3f; // -0x1.1fc000p-7
    const float d0 =  2.72656250e+0f; //  0x1.5d0000p+1
    float x2 = x * x;
    float num = fmaf (n0, x2, n1);
    float den = x2 + d0;
    float quot = num / den;
    float res = fmaf (quot, x, x);
    res = fminf (fmaxf (res, -1.0f), 1.0f);
    return res;
}

int main (void)
{
    double ref, err, maxerr = 0;
    float arg, res, maxerrloc = INFINITY;
    maxerr = 0;
    arg = 0.0f;
    while (arg < 0x1.0p64f) {
        res = fast_tanhf_rat (arg);
        ref = tanh ((double)arg);
        err = fabs ((double)res - ref);
        if (err > maxerr) {
            maxerr = err;
            maxerrloc = arg;
        }
        arg = nextafterf (arg, INFINITY);
    }
    arg = -0.0f;
    while (arg > -0x1.0p64f) {
        res = fast_tanhf_rat (arg);
        ref = tanh ((double)arg);
        err = fabs ((double)res - ref);
        if (err > maxerr) {
            maxerr = err;
            maxerrloc = arg;
        }
        arg = nextafterf (arg, -INFINITY);
    }
    printf ("maximum absolute error = %15.8e @ %15.8e\n", maxerr, maxerrloc);
    return EXIT_SUCCESS;
}

Given that the asker budgeted for up to ten operations, we can increase the degree of both numerator and denominator polynomials by one to achieve a fast tanhf() implementation comprising nine operations that has significantly lower maximum absolute error:

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/* Fast computation of hyperbolic tangent. Rational approximation with clamping.
   Maximum absolute error = 7.01054378e-5 @ +/-2.03603077
 */
float fast_tanhf_rat2 (float x)
{
    const float n0 = -9.48005676e-1f; // -0x1.e56100p-1
    const float n1 = -2.61142578e+1f; // -0x1.a1d400p+4
    const float n2 = -2.33942270e-3f; // -0x1.32a200p-9
    const float d0 =  3.41303711e+1f; //  0x1.110b00p+5
    const float d1 =  7.84101563e+1f; //  0x1.39a400p+6
    float x2 = x * x;
    float num = fmaf (fmaf (n0, x2, n1), x2, n2);
    float den = fmaf (x2 + d0, x2, d1);
    float quot = num / den;
    float res = fmaf (quot, x, x);
    res = fminf (fmaxf (res, -1.0f), 1.0f);
    return res;
}

int main (void)
{
    double ref, err, maxerr = 0;
    float arg, res, maxerrloc = INFINITY;
    maxerr = 0;
    arg = 0.0f;
    while (arg < 0x1.0p32f) {
        res = fast_tanhf_rat2 (arg);
        ref = tanh ((double)arg);
        err = fabs ((double)res - ref);
        if (err > maxerr) {
            maxerr = err;
            maxerrloc = arg;
        }
        arg = nextafterf (arg, INFINITY);
    }
    arg = -0.0f;
    while (arg > -0x1.0p32f) {
        res = fast_tanhf_rat2 (arg);
        ref = tanh ((double)arg);
        err = fabs ((double)res - ref);
        if (err > maxerr) {
            maxerr = err;
            maxerrloc = arg;
        }
        arg = nextafterf (arg, -INFINITY);
    }
    printf ("maximum absolute error = %15.8e @ %15.8e\n", maxerr, maxerrloc);
    return EXIT_SUCCESS;
}
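
Read directly off the two C functions above, the approximations have the following shape, written here in LaTeX. The clamp notation is shorthand introduced for this note and stands for the fminf/fmaxf pair; the coefficients are the n and d constants of the respective function.

```latex
% fast_tanhf_rat: seven operations
\tanh(x) \approx \operatorname{clamp}\!\left( x + x\,\frac{n_0 x^2 + n_1}{x^2 + d_0},\; -1,\; 1 \right)

% fast_tanhf_rat2: nine operations
\tanh(x) \approx \operatorname{clamp}\!\left( x + x\,\frac{(n_0 x^2 + n_1)\,x^2 + n_2}{(x^2 + d_0)\,x^2 + d_1},\; -1,\; 1 \right)
```

In both cases the leading coefficient of the denominator is fixed at 1, which is why the denominator needs one fewer constant than the numerator, and the final x + x·q step is a single FMA.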

Nic Schraudolph, author of the paper describing the exponential approximation that the previous version of this answer uses, suggests the following. It has an error of 0.5%.

Java implementation (for portable bit munging):

public class Tanh {
  private static final float m = (float)((1 << 23) / Math.log(2));
  private static final int b = Float.floatToRawIntBits(1);

  private static float tanh(float x) {
    int y = (int)(m * x);
    float exp_x = Float.intBitsToFloat(b + y);
    float exp_minus_x = Float.intBitsToFloat(b - y);
    return (exp_x - exp_minus_x) / (exp_x + exp_minus_x);
  }

  public static void main(String[] args) {
    double error = 0;
    int end = Float.floatToRawIntBits(10);
    for (int i = 0; i <= end; i++) {
      float x = Float.intBitsToFloat(i);
      error = Math.max(error, Math.abs(tanh(x) - Math.tanh(x)));
    }
    System.out.println(error);
  }
}
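
Since the rest of this page uses C, here is a rough C port of the Java tanh() above. It is a sketch under a few assumptions: the names fast_tanhf_bits and bits_to_float are made up, the constant 2^23 / ln 2 is written as its float-rounded value 12102203, memcpy is used for the bit reinterpretation, and the input is assumed to stay in a modest range (the Java test only covers 0 to 10) so that the integer conversion and additions do not overflow.

```c
#include <stdint.h>
#include <string.h>

/* Reinterpret a 32-bit integer bit pattern as a float. */
static float bits_to_float(int32_t i)
{
    float f;
    memcpy(&f, &i, sizeof f);
    return f;
}

/* Hypothetical C port of the Java tanh() above.
   Assumes |x| is small enough (e.g. |x| <= 10) to avoid integer overflow. */
float fast_tanhf_bits(float x)
{
    const float m = 12102203.0f;      /* 2^23 / ln(2), rounded to float */
    const int32_t b = 0x3f800000;     /* bit pattern of 1.0f */
    int32_t y = (int32_t)(m * x);     /* scaled argument, truncated toward zero */
    float exp_x = bits_to_float(b + y);        /* rough exp(x) */
    float exp_minus_x = bits_to_float(b - y);  /* rough exp(-x) */
    return (exp_x - exp_minus_x) / (exp_x + exp_minus_x);
}
```

On the asker's machine the memcpy calls would correspond to the free float/integer reinterpretation mentioned in the question, leaving a multiply, a float-to-int conversion, two integer adds, two float adds and a divide.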
