Fastest Implementation of the Natural Exponential Function Using SSE

I'm looking for an approximation of the natural exponential function operating on an SSE register, namely __m128 exp(__m128 x).

I have an implementation which is quick but seems to be very low in accuracy:

static inline __m128 FastExpSse(__m128 x)
{
    __m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
    __m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
    __m128  m87 = _mm_set1_ps(-87);
    // fast exponential function, x should be in [-87, 87]
    __m128 mask = _mm_cmpge_ps(x, m87);

    __m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
    return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}

Does anybody have an implementation with better accuracy that is just as fast (or faster)?

I'd prefer it to be written in C style.

Thank you.



Solution 1:[1]

The C code below is a translation into SSE intrinsics of an algorithm I used in a previous answer to a similar question.

The basic idea is to transform the computation of the standard exponential function into computation of a power of 2: expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504). We split t = x * 1.44269504 into an integer i and a fraction f, such that t = i + f and 0 <= f <= 1. We can now compute 2^f with a polynomial approximation, then scale the result by 2^i by adding i to the exponent field of the single-precision floating-point result.
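The decomposition can be illustrated with a scalar sketch (not part of the original answer; fast_exp_scalar is a hypothetical name). It assumes IEEE-754 single precision and x roughly in [-87.3, 88.7], computes an exact floor with an integer cast plus a fix-up, and reuses the quadratic coefficients from fast_exp_sse():

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical scalar illustration of exp(x) = 2^i * 2^f. */
static float fast_exp_scalar(float x)
{
    float t = x * 1.442695041f;               /* t = log2(e) * x */
    int32_t i = (int32_t)t;                   /* truncate toward zero ... */
    if (t < (float)i) i -= 1;                 /* ... then fix up negative non-integers: i = floor(t) */
    float f = t - (float)i;                   /* fraction, 0 <= f < 1 */
    float p = (0.3371894346f * f + 0.657636276f) * f + 1.00172476f;  /* p ~= 2^f */
    int32_t bits;
    memcpy(&bits, &p, sizeof bits);           /* reinterpret p as its IEEE-754 bit pattern */
    bits += i * (1 << 23);                    /* add i to the exponent field: p * 2^i */
    memcpy(&p, &bits, sizeof p);
    return p;
}
```

For example, fast_exp_scalar(1.0f) lands within about 0.02% of e; the SSE versions do the same work four lanes at a time.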

One problem with an SSE implementation is that we want to compute i = floorf (t), but plain SSE (before SSE4.1) has no fast way to compute the floor() function. However, we observe that for positive numbers, floor(x) == trunc(x), and that for negative numbers, floor(x) == trunc(x) - 1, except when x is a negative integer. Since the core approximation can handle an f value of 1.0f, applying trunc(x) - 1 to negative integer arguments as well is harmless: it simply yields f = 1 instead of f = 0. SSE provides an instruction that converts single-precision floating-point operands to integers with truncation (_mm_cvttps_epi32()), so this solution is efficient.
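To see what the truncate-and-subtract-sign-bit trick produces, the hypothetical helper below applies it to four lanes and stores i and f. It takes the sign bit from t itself, which has the same sign as x because log2(e) > 0. For the negative integer -2.0 it yields i = -3 and f = 1.0 rather than i = -2 and f = 0, which is the harmless case described above.

```c
#include <emmintrin.h>   /* SSE2 */

/* Hypothetical helper: i = (int)t - signbit(t), f = t - i, per lane. */
static void floor_trick_sse(__m128 t, int i_out[4], float f_out[4])
{
    __m128i i = _mm_cvttps_epi32(t);                      /* truncate toward zero */
    __m128i s = _mm_srli_epi32(_mm_castps_si128(t), 31);  /* 1 in negative lanes, else 0 */
    i = _mm_sub_epi32(i, s);                              /* floor(t), except at negative integers */
    __m128 f = _mm_sub_ps(t, _mm_cvtepi32_ps(i));         /* 0 <= f <= 1 */
    _mm_storeu_si128((__m128i *)i_out, i);
    _mm_storeu_ps(f_out, f);
}
```

Calling it with _mm_setr_ps(2.75f, -2.25f, -2.0f, 0.0f) gives i = {2, -3, -3, 0} and f = {0.75, 0.75, 1.0, 0.0}.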

Peter Cordes points out that SSE4.1 supports a fast floor function, _mm_floor_ps(), so a variant using SSE4.1 is also shown below. Not all toolchains automatically predefine the macro __SSE4_1__ when SSE4.1 code generation is enabled, but gcc does.

Compiler Explorer (Godbolt) shows that gcc 7.2 compiles the code below into sixteen instructions for plain SSE and twelve instructions for SSE 4.1.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif

/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, e, p, r;
    __m128i i, j;
    __m128 l2e = _mm_set1_ps (1.442695041f);  /* log2(e) */
    __m128 c0  = _mm_set1_ps (0.3371894346f);
    __m128 c1  = _mm_set1_ps (0.657636276f);
    __m128 c2  = _mm_set1_ps (1.00172476f);

    /* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */   
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
#ifdef __SSE4_1__
    e = _mm_floor_ps (t);                /* floor(t) */
    i = _mm_cvtps_epi32 (e);             /* (int)floor(t) */
#else /* __SSE4_1__*/
    i = _mm_cvttps_epi32 (t);            /* i = (int)t */
    j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
    i = _mm_sub_epi32 (i, j);            /* (int)t - signbit(t) */
    e = _mm_cvtepi32_ps (i);             /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
    f = _mm_sub_ps (t, e);               /* f = t - floor(t) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

int main (void)
{
    union {
        float f[4];
        unsigned int i[4];
    } arg, res;
    double relerr, maxrelerr = 0.0;
    int i, j;
    __m128 x, y;

    float start[2] = {-0.0f, 0.0f};
    float finish[2] = {-87.33654f, 88.72283f};

    for (i = 0; i < 2; i++) {

        arg.f[0] = start[i];
        arg.i[1] = arg.i[0] + 1;
        arg.i[2] = arg.i[0] + 2;
        arg.i[3] = arg.i[0] + 3;
        do {
            memcpy (&x, &arg, sizeof(x));
            y = fast_exp_sse (x);
            memcpy (&res, &y, sizeof(y));
            for (j = 0; j < 4; j++) {
                double ref = exp ((double)arg.f[j]);
                relerr = fabs ((res.f[j] - ref) / ref);
                if (relerr > maxrelerr) {
                    printf ("arg=% 15.8e  res=%15.8e  ref=%15.8e  err=%15.8e\n", 
                            arg.f[j], res.f[j], ref, relerr);
                    maxrelerr = relerr;
                }
            }   
            arg.i[0] += 4;
            arg.i[1] += 4;
            arg.i[2] += 4;
            arg.i[3] += 4;
        } while (fabsf (arg.f[3]) < fabsf (finish[i]));
    }
    printf ("maximum relative error = %15.8e\n", maxrelerr);
    return EXIT_SUCCESS;
}

An alternative design for fast_exp_sse() extracts the integer portion of the adjusted argument x / log(2) in round-to-nearest mode, using the well-known technique of adding the "magic" conversion constant 1.5 * 2^23 to force rounding in the correct bit position, then subtracting the same number again. This requires that the SSE rounding mode in effect during the addition is "round to nearest, ties to even", which is the default. wim pointed out in comments that some compilers may optimize out the addition and subtraction of the conversion constant cvt as redundant when aggressive optimization is used, which breaks this code sequence, so it is recommended to inspect the generated machine code. The approximation interval for computing 2^f is now centered around zero, since -0.5 <= f <= 0.5, which requires a different core approximation.

/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, p, r;
    __m128i i, j;

    const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
    const __m128 cvt = _mm_set1_ps (12582912.0f);  /* 1.5 * (1 << 23) */
    const __m128 c0 =  _mm_set1_ps (0.238428936f);
    const __m128 c1 =  _mm_set1_ps (0.703448006f);
    const __m128 c2 =  _mm_set1_ps (1.000443142f);

    /* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
    r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
    f = _mm_sub_ps (t, r);               /* f = t - rint (t) */
    i = _mm_cvtps_epi32 (t);             /* i = (int)rint (t): cvtps rounds to nearest */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}
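The rounding step r = (t + cvt) - cvt in the code above can be checked in isolation with a scalar sketch (rint_magic is a hypothetical name). It assumes |t| < 2^22, the default round-to-nearest-even mode, and float arithmetic evaluated in single precision, as with SSE code (x87 code may behave differently). The volatile temporary is one way to stop a compiler, for example under -ffast-math, from folding the add and subtract away, which is the problem wim's comment describes.

```c
/* Hypothetical scalar version of the "magic constant" rounding. */
static float rint_magic(float t)
{
    const float cvt = 12582912.0f;   /* 1.5 * 2^23: the spacing of floats near t + cvt is 1.0 */
    volatile float sum = t + cvt;    /* fraction bits are rounded off here (ties to even) */
    return sum - cvt;                /* exact subtraction: recovers rint(t) */
}
```

With this, rint_magic(2.5f) returns 2.0f and rint_magic(3.5f) returns 4.0f (ties go to even), while rint_magic(-1.7f) returns -2.0f.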

The algorithm for the code in the question appears to be taken from the work of Nicol N. Schraudolph, which cleverly exploits the semi-logarithmic nature of IEEE-754 binary floating-point formats:

N. N. Schraudolph. "A fast, compact approximation of the exponential function." Neural Computation, 11(4), May 1999, pp.853-862.

After removal of the argument clamping code, it reduces to just three SSE instructions. The "magical" correction constant 486411 is not optimal for minimizing maximum relative error over the entire input domain. Based on simple binary search, the value 298765 seems to be superior, reducing maximum relative error for FastExpSse() to 3.56e-2 vs. maximum relative error of 1.73e-3 for fast_exp_sse().

/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
    __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    __m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
    __m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
    return _mm_castsi128_ps (t);
}
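A scalar sketch of the same three steps may make the bit trick easier to follow. Writing x / ln(2) = i + f with i = floor(x / ln(2)), the integer (127 + i + f) * 2^23, read as float bits, has exponent field 127 + i and mantissa f, i.e. the value 2^i * (1 + f); subtracting 298765 shifts this piecewise-linear approximation of 2^f down to balance the error. The name fast_exp_schraudolph is hypothetical, and _mm_cvtss_si32 is used because it rounds the same way as _mm_cvtps_epi32 (current MXCSR mode, round-to-nearest by default). It assumes x roughly in [-87.3, 88.7].

```c
#include <stdint.h>
#include <string.h>
#include <xmmintrin.h>   /* _mm_set_ss, _mm_cvtss_si32 */

/* Hypothetical scalar counterpart of FastExpSse() above. */
static float fast_exp_schraudolph(float x)
{
    /* t = (x / ln 2) * 2^23, rounded to nearest like _mm_cvtps_epi32 */
    int32_t t = _mm_cvtss_si32(_mm_set_ss(12102203.0f * x));
    /* add the exponent bias 127 << 23, minus the correction constant 298765 */
    int32_t bits = t + (127 * (1 << 23) - 298765);
    float r;
    memcpy(&r, &bits, sizeof r);   /* exponent field 127 + i, mantissa bits ~ f */
    return r;
}
```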

Schraudolph's algorithm basically uses the linear approximation 2^f ~= 1.0 + f for f in [0,1], and its accuracy could be improved by adding a quadratic term. The clever part of Schraudolph's approach is computing 2^i * 2^f without explicitly separating the integer portion i = floor(x * 1.44269504) from the fraction. I see no way to extend that trick to a quadratic approximation, but one can certainly combine the floor() computation from Schraudolph with the quadratic approximation used above:

/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 f, p, r;
    __m128i t, j;
    const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
    const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
    const __m128 c0 = _mm_set1_ps (0.3371894346f);
    const __m128 c1 = _mm_set1_ps (0.657636276f);
    const __m128 c2 = _mm_set1_ps (1.00172476f);

    t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
    j = _mm_and_si128 (t, m);            /* j = (int)(floor (x/log(2))) << 23 */
    t = _mm_sub_epi32 (t, j);
    f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

Solution 2:[2]

A good increase in accuracy in my algorithm (implementation FastExpSse in the answer above) can be obtained at the cost of an integer subtraction and floating-point division by using FastExpSse(x/2)/FastExpSse(-x/2) instead of FastExpSse(x). The trick here is to set the shift parameter (298765 above) to zero so that the piecewise linear approximations in the numerator and denominator line up to give you substantial error cancellation. Roll it into a single function:

__m128 BetterFastExpSse (__m128 x)
{
  const __m128 a = _mm_set1_ps ((1 << 22) / (float)M_LN2);  // to get exp(x/2); M_LN2 from <math.h>
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));       // NB: zero shift!
  __m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
  __m128i s = _mm_add_epi32 (b, r);
  __m128i t = _mm_sub_epi32 (b, r);
  return _mm_div_ps (_mm_castsi128_ps (s), _mm_castsi128_ps (t));
}

(I'm not a hardware guy - how bad a performance killer is the division here?)

If you need exp(x) just to get y = tanh(x) (e.g. for neural networks), use FastExpSse with zero shift as follows:

a = FastExpSse(x);
b = FastExpSse(-x);
y = (a - b)/(a + b);

to get the same type of error cancellation benefit. The logistic function works similarly, using FastExpSse(x/2)/(FastExpSse(x/2) + FastExpSse(-x/2)) with zero shift. (This is just to show the principle - you obviously don't want to evaluate FastExpSse multiple times here, but roll it into a single function along the lines of BetterFastExpSse above.)
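Rolled into single functions along the lines of BetterFastExpSse, the tanh and logistic recipes might look like the sketch below. FastTanhSse and FastLogisticSse are hypothetical names; both use zero shift, and they assume |x| is small enough (roughly |x| < 87 for tanh, |x| < 174 for the logistic) that the integer exponent arithmetic stays in the normal float range.

```c
#include <emmintrin.h>   /* SSE2 */

/* tanh(x) = (e^x - e^-x) / (e^x + e^-x), with zero-shift FastExpSse terms */
static __m128 FastTanhSse(__m128 x)
{
    const __m128  a = _mm_set1_ps((float)(8388608.0 / 0.69314718055994531)); /* 2^23 / ln(2) */
    const __m128i b = _mm_set1_epi32(127 * (1 << 23));                        /* NB: zero shift */
    __m128i r  = _mm_cvtps_epi32(_mm_mul_ps(a, x));
    __m128  ep = _mm_castsi128_ps(_mm_add_epi32(b, r));   /* ~ exp(x)  */
    __m128  em = _mm_castsi128_ps(_mm_sub_epi32(b, r));   /* ~ exp(-x) */
    return _mm_div_ps(_mm_sub_ps(ep, em), _mm_add_ps(ep, em));
}

/* logistic(x) = e^(x/2) / (e^(x/2) + e^(-x/2)) */
static __m128 FastLogisticSse(__m128 x)
{
    const __m128  a = _mm_set1_ps((float)(4194304.0 / 0.69314718055994531)); /* 2^22 / ln(2) */
    const __m128i b = _mm_set1_epi32(127 * (1 << 23));                        /* NB: zero shift */
    __m128i r  = _mm_cvtps_epi32(_mm_mul_ps(a, x));
    __m128  ep = _mm_castsi128_ps(_mm_add_epi32(b, r));   /* ~ exp(x/2)  */
    __m128  em = _mm_castsi128_ps(_mm_sub_epi32(b, r));   /* ~ exp(-x/2) */
    return _mm_div_ps(ep, _mm_add_ps(ep, em));
}
```

Both functions return exactly 0 and 0.5 at x = 0, because the two exponential terms are then identical.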

I did develop a series of higher-order approximations from this, ever more accurate but also slower. Unpublished but happy to collaborate if anyone wants to give them a spin.

And finally, for some fun: use in reverse gear to get FastLogSse. Chaining that with FastExpSse gives you both operator and error cancellation, and out pops a blazingly fast power function...
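One possible reading of the "reverse gear" is sketched below: FastLogSse reinterprets the float bits as an integer and undoes the FastExpSse steps, and chaining it with a zero-shift FastExpSse gives a power function. The bodies, and the names FastExpSse0 and FastPowSse, are assumptions for illustration. With zero shift the two steps invert each other almost exactly, so FastExpSse0(FastLogSse(x)) returns x and powers of two come out exact; for other bases the error can be large (in this sketch FastPowSse(3, 2) returns 8 instead of 9, because the piecewise-linear log2 maps 3 to 1.5), so check the accuracy for your use case.

```c
#include <emmintrin.h>   /* SSE2 */

/* FastExpSse with zero shift (no correction constant). */
static __m128 FastExpSse0(__m128 x)
{
    const __m128  a = _mm_set1_ps(12102203.0f);            /* 2^23 / ln(2) */
    const __m128i b = _mm_set1_epi32(127 * (1 << 23));
    return _mm_castsi128_ps(_mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b));
}

/* "Reverse gear": float bits as integer, minus the bias, scaled by ln(2) / 2^23.
   Assumes x is positive and normal. */
static __m128 FastLogSse(__m128 x)
{
    const __m128i b = _mm_set1_epi32(127 * (1 << 23));
    const __m128  c = _mm_set1_ps((float)(0.69314718055994531 / 8388608.0));  /* ln(2) / 2^23 */
    __m128i i = _mm_sub_epi32(_mm_castps_si128(x), b);     /* ~ log2(x) * 2^23 */
    return _mm_mul_ps(_mm_cvtepi32_ps(i), c);
}

/* x^y = exp(y * log(x)) for x > 0 */
static __m128 FastPowSse(__m128 x, __m128 y)
{
    return FastExpSse0(_mm_mul_ps(y, FastLogSse(x)));
}
```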

Solution 3:[3]

Going back through my notes from way back then, I did explore ways to improve the accuracy without using division. I used the same reinterpret-as-float trick but applied a polynomial correction to the mantissa which was essentially calculated in 16-bit fixed-point arithmetic (the only way to do it fast back then).

The cubic and quartic versions give you 4 and 5 significant digits of accuracy, respectively. There was no point increasing the order beyond that, as the noise of the low-precision arithmetic then starts to drown out the error of the polynomial approximation. Here are the plain C versions:

#include <stdint.h>

float fastExp3(register float x)  // cubic spline approximation
{
    union { float f; int32_t i; } reinterpreter;

    reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
    int32_t m = (reinterpreter.i >> 7) & 0xFFFF;  // copy mantissa
    // empirical values for small maximum relative error (8.34e-5):
    reinterpreter.i +=
         ((((((((1277*m) >> 14) + 14825)*m) >> 14) - 79749)*m) >> 11) - 626;
    return reinterpreter.f;
}

float fastExp4(register float x)  // quartic spline approximation
{
    union { float f; int32_t i; } reinterpreter;

    reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
    int32_t m = (reinterpreter.i >> 7) & 0xFFFF;  // copy mantissa
    // empirical values for small maximum relative error (1.21e-5):
    reinterpreter.i += (((((((((((3537*m) >> 16)
        + 13668)*m) >> 18) + 15817)*m) >> 14) - 80470)*m) >> 11);
    return reinterpreter.f;
}

The quartic version satisfies fastExp4(0.0f) == 1.0f exactly, which can be important for fixed-point iteration algorithms.

How efficient are these integer multiply-shift-add sequences in SSE? On architectures where float arithmetic is just as fast, one could use that instead, reducing the arithmetic noise. This would essentially yield cubic and quartic extensions of @njuffa's answer above.
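One way to express fastExp4() with intrinsics is sketched below (fastExp4Sse is a hypothetical name). The 32-bit multiplies need _mm_mullo_epi32 (pmulld), which is SSE4.1; plain SSE2 has no packed 32-bit low multiply. For 16-bit mantissas m the intermediate products stay within int32, so the lanes should match the scalar code bit for bit, assuming the scalar >> on negative values is an arithmetic shift, as on common compilers. The GCC/Clang target attribute is an assumption about the toolchain (it lets recent GCC/Clang compile this without -msse4.1, but the CPU must still support SSE4.1); with other compilers, enable SSE4.1 code generation instead.

```c
#include <smmintrin.h>   /* SSE4.1: _mm_mullo_epi32 (also provides SSE/SSE2) */

/* Hypothetical SSE4.1 port of fastExp4(); same constants and shifts. */
__attribute__((target("sse4.1")))
static __m128 fastExp4Sse(__m128 x)
{
    const __m128i mask16 = _mm_set1_epi32(0xFFFF);
    /* (int32_t)(12102203.0f*x) + 127*(1 << 23), truncating like the C cast */
    __m128i i = _mm_add_epi32(_mm_cvttps_epi32(_mm_mul_ps(_mm_set1_ps(12102203.0f), x)),
                              _mm_set1_epi32(127 * (1 << 23)));
    __m128i m = _mm_and_si128(_mm_srai_epi32(i, 7), mask16);   /* copy mantissa */
    /* (((((((3537*m >> 16) + 13668)*m >> 18) + 15817)*m >> 14) - 80470)*m) >> 11 */
    __m128i p = _mm_srai_epi32(_mm_mullo_epi32(_mm_set1_epi32(3537), m), 16);
    p = _mm_srai_epi32(_mm_mullo_epi32(_mm_add_epi32(p, _mm_set1_epi32(13668)), m), 18);
    p = _mm_srai_epi32(_mm_mullo_epi32(_mm_add_epi32(p, _mm_set1_epi32(15817)), m), 14);
    p = _mm_srai_epi32(_mm_mullo_epi32(_mm_sub_epi32(p, _mm_set1_epi32(80470)), m), 11);
    return _mm_castsi128_ps(_mm_add_epi32(i, p));               /* reinterpret as float */
}
```

Like the scalar quartic, it returns exactly 1.0f for x = 0, since m and therefore the correction are zero there.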

Solution 4:[4]

For softmax use, I'm envisioning the flow as:

auto a = _mm_mul_ps(x, _mm_set1_ps(12102203.2f));
auto b = _mm_castsi128_ps(_mm_cvtps_epi32(a)); // so far as in other variants

// copy 9 MSB from 0x3f800000 over 'b' so that 1 <= c < 2
//  - also  1 <= poly_eval(...) < 2
auto c = replace_exponent(b, _mm_set1_ps(1.0f));
auto d = poly_eval(c, kA, kB, kC);  // 2nd degree polynomial
auto e = replace_exponent(d, b);    // restore exponent : 2^i * 2^f

The exponent copying can be done as a bitwise select using a suitable mask (AVX-512 has vpternlogd, and I'm actually using Arm Neon vbsl).

All the input values x must be negative and clamped to -17-f(N) <= x <= -f(N), so that, when scaled by (1<<23)/log(2), the maximum sum of the N resulting floating-point values does not reach infinity and its reciprocal does not become denormal. For N=3, f(N) = 4. A larger f(N) trades off input precision.

The poly_eval coefficients can be generated, for example, by polyfit([1 1.5 2],[1 sqrt(2) 2]), which gives kA=0.343146, kB=-0.029437, kC=0.686292, producing values strictly smaller than 2 and preventing discontinuities. The maximum average error can be reduced by fitting the polynomial through x=[1+max_err 1.5-eps 2], y=[1 2^(.5-eps) 2-max_err] instead.

For strictly SSE/AVX, exponent replacement for 1.0f can be done with (x & 0x007fffff) | 0x3f800000. A two-instruction sequence for the latter exponent replacement can be found by ensuring that poly_eval(x) evaluates to a range that can be ORed directly with b & 0xff800000.
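A minimal SSE2 sketch of this flow for a four-element softmax follows. replace_exponent() and poly_eval() are the pseudocode names from above with one possible implementation each; softmax4_sketch() is hypothetical. As in the flow above, no 127 << 23 bias is added, so the sign and exponent bits come from the two's-complement integer part and each lane of e holds roughly -2^129 * exp(x); that common factor cancels in the final division. The constant term 0.686292 is what the three-point fit through (1, 1), (1.5, sqrt(2)), (2, 2) gives (it makes kA + kB + kC = 1), and the clamp of d to [1, 2) is an extra safeguard added here, not part of the original description. Inputs are assumed to be clamped to [-21, -4] already.

```c
#include <emmintrin.h>   /* SSE2 */

/* Mantissa (low 23 bits) from mant_src, sign + exponent (9 MSB) from exp_src. */
static __m128 replace_exponent(__m128 mant_src, __m128 exp_src)
{
    const __m128 mant_mask = _mm_castsi128_ps(_mm_set1_epi32(0x007fffff));
    return _mm_or_ps(_mm_and_ps(mant_src, mant_mask),
                     _mm_andnot_ps(mant_mask, exp_src));
}

/* (kA*c + kB)*c + kC: the 2nd degree polynomial in Horner form */
static __m128 poly_eval(__m128 c, __m128 kA, __m128 kB, __m128 kC)
{
    return _mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(kA, c), kB), c), kC);
}

/* Softmax over four lanes; every lane is assumed to lie in [-21, -4]. */
static __m128 softmax4_sketch(__m128 x)
{
    const __m128 kA = _mm_set1_ps(0.343146f);
    const __m128 kB = _mm_set1_ps(-0.029437f);
    const __m128 kC = _mm_set1_ps(0.686292f);
    __m128 a = _mm_mul_ps(x, _mm_set1_ps(12102203.2f));
    __m128 b = _mm_castsi128_ps(_mm_cvtps_epi32(a));      /* no 127 << 23 bias */
    __m128 c = replace_exponent(b, _mm_set1_ps(1.0f));   /* 1 <= c < 2 */
    __m128 d = poly_eval(c, kA, kB, kC);                  /* d ~= 2^(c - 1) */
    d = _mm_max_ps(d, _mm_set1_ps(1.0f));                 /* safeguard added here: */
    d = _mm_min_ps(d, _mm_set1_ps(1.99999988f));          /* keep 1 <= d < 2 */
    __m128 e = replace_exponent(d, b);                    /* ~ -2^129 * exp(x) */
    /* horizontal sum: every lane of s ends up holding e0 + e1 + e2 + e3 */
    __m128 s = _mm_add_ps(e, _mm_shuffle_ps(e, e, _MM_SHUFFLE(2, 3, 0, 1)));
    s = _mm_add_ps(s, _mm_shuffle_ps(s, s, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_div_ps(e, s);                              /* -2^129 cancels */
}
```

For inputs (-4, -5, -6, -21) the exact softmax is about (0.6652, 0.2447, 0.0900, 2.8e-8); the sketch should agree to within roughly 1%.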

Solution 5:[5]

I have developed for my purposes the following function, which calculates the natural exponential quickly and accurately in single precision. The function works over the entire range of float values. The code is written for Visual Studio (32-bit x86 inline assembly). It uses AVX and FMA3 instructions instead of SSE, but that shouldn't be a problem. The accuracy of this function is almost the same as that of the standard expf function, but it is significantly faster. The approximation used is based on the Chebyshev series expansion of the function f(t)=t/(2^(t/2)-1)+t/2 for t in [-1, 1]. I thank Peter Cordes for his good advice.
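The identity behind this approximation can be written out as follows, with k = round(x * lb(e)) and t/2 = x * lb(e) - k as in the code comments below, so that -1 <= t <= 1:

```latex
e^{x} = 2^{x\log_2 e} = 2^{k}\,2^{t/2},
\qquad
f(t) = \frac{t}{2^{t/2}-1} + \frac{t}{2}
\;\Longrightarrow\;
2^{t/2} = 1 + \frac{t}{f(t) - t/2},
\qquad
e^{x} = 2^{k}\left(1 + \frac{t}{f(t) - t/2}\right).
```

One can check that f(-t) = f(t), so f is even, which is consistent with the fitted form f(t) ~= b0 + b1*t^2 + b2*t^4 in the code comments; its value at zero is f(0) = 2/ln(2) ~= 2.8853900, matching b0 in the constants table.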

_declspec(naked) float _vectorcall fexp(float x)
{
  static const float ct[7] =       // Constants table
  {
    1.44269502f,                   // lb(e)
    1.92596299E-8f,                // Correction to the value lb(e)
    -9.21120925E-4f,               // 16*b2
    0.115524396f,                  // 4*b1
    2.88539004f,                   // b0
    2.0f,                          // 2
    4.65661287E-10f                // 2^-31
  };
  _asm
  {
    mov ecx,offset ct              // ecx contains the address of constants tables
    vmulss xmm1,xmm0,[ecx]         // xmm1 = x*lb(e)
    vcvtss2si eax,xmm1             // eax = round(x*lb(e)) = k
    cdq                            // edx=-1, if x<0 or overflow, otherwise edx=0
    vmovss xmm3,[ecx+8]            // Initialize the sum with highest coefficient 16*b2
    and edx,4                      // edx=4, if x<0 or overflow, otherwise edx=0
    vcvtsi2ss xmm1,xmm1,eax        // xmm1 = k
    lea eax,[eax+8*edx]            // Add 32 to exponent, if x<0
    vfmsub231ss xmm1,xmm0,[ecx]    // xmm1 = x*lb(e)-k = t/2 in the range from -0.5 to 0.5
    add eax,126                    // The exponent of 2^(k-1) or 2^(k+31) with bias 127
    jle exp_low                    // Jump if x<<0 or overflow (|x| too large or x=NaN)
    vfmadd132ss xmm0,xmm1,[ecx+4]  // xmm0 = t/2 (corrected value)
    cmp eax,254                    // Check that the exponent is not too large
    jg exp_inf                     // Jump to set Inf if overflow
    vmulss xmm2,xmm0,xmm0          // xmm2 = t^2/4 - the argument of the polynomial
    shl eax,23                     // The bits of the float value 2^(k-1) or 2^(k+31)
    vfmadd213ss xmm3,xmm2,[ecx+12] // xmm3 = 4*b1+4*b2*t^2
    vmovd xmm1,eax                 // xmm1 = 2^(k-1) or 2^(k+31)
    vfmsub213ss xmm3,xmm2,xmm0     // xmm3 = -t/2+b1*t^2+b2*t^4
    vaddss xmm0,xmm0,xmm0          // xmm0 = t
    vaddss xmm3,xmm3,[ecx+16]      // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
    vdivss xmm0,xmm0,xmm3          // xmm0 = t/(f(t)-t/2)
    vfmadd213ss xmm0,xmm1,xmm1     // xmm0 = e^x with shifted exponent of -1 or 31
    vmulss xmm0,xmm0,[ecx+edx+20]  // xmm0 = e^x
    ret                            // Return
      exp_low:                     // Handling the case of x<<0 or overflow
    vucomiss xmm0,[ecx]            // Check the sign of x and a condition x=NaN
    jp exp_end                     // Complete with NaN result, if x=NaN
      exp_inf:                     // Entry point for processing large x
    vxorps xmm0,xmm0,xmm0          // xmm0 = 0
    jc exp_end                     // Ready, if x<<0
    vrcpss xmm0,xmm0,xmm0          // xmm0 = Inf in case x>>0
      exp_end:                     // The result at xmm0 is ready
    ret                            // Return
  }
}

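Below I post a simplified algorithm. Support for denormalized numbers in the result is removed here. Because k is limited to 127, this version also returns infinity for x above about 88.38, slightly below the true overflow threshold of about 88.72.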

_declspec(naked) float _vectorcall fexp(float x)
{
  static const float ct[5] =       // Constants table
  {
    1.44269502f,                   // lb(e)
    1.92596299E-8f,                // Correction to the value lb(e)
    -9.21120925E-4f,               // 16*b2
    0.115524396f,                  // 4*b1
    2.88539004f                    // b0
  };
  _asm
  {
    mov edx,offset ct              // edx contains the address of constants tables
    vmulss xmm1,xmm0,[edx]         // xmm1 = x*lb(e)
    vcvtss2si eax,xmm1             // eax = round(x*lb(e)) = k
    vmovss xmm3,[edx+8]            // Initialize the sum with highest coefficient 16*b2
    vcvtsi2ss xmm1,xmm1,eax        // xmm1 = k
    cmp eax,127                    // Check that the exponent is not too large
    jg exp_break                   // Jump to set Inf if overflow
    vfmsub231ss xmm1,xmm0,[edx]    // xmm1 = x*lb(e)-k = t/2 in the range from -0.5 to 0.5
    add eax,127                    // Receive the exponent of 2^k with the bias 127
    jle exp_break                  // The result is 0, if x<<0
    vfmadd132ss xmm0,xmm1,[edx+4]  // xmm0 = t/2 (corrected value)
    vmulss xmm2,xmm0,xmm0          // xmm2 = t^2/4 - the argument of polynomial
    shl eax,23                     // eax contains the bits of 2^k
    vfmadd213ss xmm3,xmm2,[edx+12] // xmm3 = 4*b1+4*b2*t^2
    vmovd xmm1,eax                 // xmm1 = 2^k
    vfmsub213ss xmm3,xmm2,xmm0     // xmm3 = -t/2+b1*t^2+b2*t^4
    vaddss xmm0,xmm0,xmm0          // xmm0 = t
    vaddss xmm3,xmm3,[edx+16]      // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
    vdivss xmm0,xmm0,xmm3          // xmm0 = t/(f(t)-t/2)
    vfmadd213ss xmm0,xmm1,xmm1     // xmm0 = 2^k*(t/(f(t)-t/2)+1) = e^x
    ret                            // Return
      exp_break:                   // Get 0 for x<0 or Inf for x>>0
    vucomiss xmm0,[edx]            // Check the sign of x and a condition x=NaN
    jp exp_end                     // Complete with NaN result, if x=NaN
    vxorps xmm0,xmm0,xmm0          // xmm0 = 0
    jc exp_end                     // Ready, if x<<0
    vrcpss xmm0,xmm0,xmm0          // xmm0 = Inf, if x>>0
      exp_end:                     // The result at xmm0 is ready
    ret                            // Return
  }
}
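For readers without MSVC inline assembly, the core of the simplified routine can be sketched in portable C. This is a hypothetical translation (fexp_c and pow2i are made-up names) that keeps the constants and the rational formula but differs from the assembly in two deliberate ways: the argument reduction x * lb(e) - k is done in double instead of with the lb(e) correction constant and FMA, and 2^k is applied as two normal power-of-two factors, so results close to FLT_MAX stay finite and tiny results underflow gradually.

```c
#include <math.h>     /* INFINITY, NAN (macros only, no libm calls) */
#include <stdint.h>
#include <string.h>

/* 2^n as a float, valid for n in [-126, 127] */
static float pow2i(int n)
{
    uint32_t bits = (uint32_t)(n + 127) << 23;
    float r;
    memcpy(&r, &bits, sizeof r);
    return r;
}

/* Hypothetical portable sketch of the simplified fexp() above. */
static float fexp_c(float x)
{
    const float b2x16 = -9.21120925E-4f;   /* 16*b2 */
    const float b1x4  = 0.115524396f;      /* 4*b1  */
    const float b0    = 2.88539004f;       /* b0 ~= 2/ln(2) */

    if (x != x) return x;                  /* NaN in, NaN out */
    if (x > 89.0f) return INFINITY;        /* beyond ln(FLT_MAX) ~= 88.72 */
    if (x < -104.0f) return 0.0f;          /* below the denormal range */

    double y = (double)x * 1.4426950408889634;    /* x * lb(e), reduced in double */
    int k = (int)(y < 0.0 ? y - 0.5 : y + 0.5);   /* k = round(x * lb(e)) */
    float h = (float)(y - k);                     /* h = t/2, |h| <= 0.5 */

    float s = h * h;                              /* t^2/4 */
    float q = (b2x16 * s + b1x4) * s - h + b0;    /* f(t) - t/2 */
    float m = 1.0f + (h + h) / q;                 /* 2^(t/2) = 1 + t/(f(t) - t/2) */

    int k1 = k / 2;                               /* split 2^k into two normal factors */
    return (m * pow2i(k1)) * pow2i(k - k1);
}
```

In this sketch fexp_c(0.0f) returns exactly 1.0f, fexp_c(100.0f) returns infinity and fexp_c(-200.0f) returns 0.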

Solution 6:[6]

There is a paper about creating fast versions of these functions (tanh, cosh, artanh, sinh, etc.):

http://ijeais.org/wp-content/uploads/2018/07/IJAER180702.pdf "Creating a Compiler Optimized Inlineable Implementation of Intel Svml Simd Intrinsics"

Their tanh approximation (equation 6 on page 9) is very similar to @NicSchraudolph's answer.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Roland
Solution 2
Solution 3
Solution 4 Aki Suihkonen
Solution 5
Solution 6 Kari