Faster Logistic Function

An answer to this question on the Scientific Computing Stack Exchange.

Question

I've noticed that a fairly significant number of cycles in one of my programs are being consumed by the logistic function: $$f(x)=\frac{1}{1+e^{-x}}$$ Is there a good approximation I can use to reduce the cost of this function?

Answer

Yes! There are nice approximations of the logistic.

Plot of Approximating Functions

As shown below, several functions approximate the logistic (shown as blue dots). This graph is available interactively here.

[Figure: plot of various approximations of the logistic function]

Benchmark Results of Approximating Functions

To explore this better, I've benchmarked all of the functions using the code at the bottom of this post. The results are as follows (time_us is microseconds per sample; speedup is relative to logistic_orig):

                                    name      rms_error        maxdiff        time_us        speedup        samples
                      logistic_with_tanh     5.9496e-02     1.5014e-01         0.0393         0.5076      200000001
                      logistic_with_atan     3.9051e-02     9.6934e-02         0.0321         0.6211      200000001
                       logistic_with_erf     6.5068e-02     1.6581e-01         0.0299         0.6676      200000001
            logistic_fexp_quintic_approx     1.2921e-07     5.9050e-07         0.0246         0.8118      200000001
        logistic_product_approx_float128     8.7032e-04     1.7217e-03         0.0209         0.9523      200000001
           logistic_with_exp_no_overflow     4.7660e-17     1.6653e-16         0.0198         1.0084      200000001
              logistic_product_approx128     8.7032e-04     1.7211e-03         0.0164         1.2187      200000001
         log_w_approx_exp_no_overflow128     8.7193e-04     1.7211e-03         0.0158         1.2640      200000001
                      logistic_with_sqrt     8.3414e-02     1.1086e-01         0.0146         1.3662      200000001
          log_w_approx_exp_no_overflow16     6.9726e-03     1.4074e-02         0.0141         1.4114      200000001
  log_w_approx_exp_no_overflow16_clamped     6.9726e-03     1.4074e-02         0.0141         1.4153      200000001
             logistic_schraudolph_approx     1.5661e-03     8.9906e-03         0.0138         1.4497      200000001
                       logistic_with_abs     6.0968e-02     8.2289e-02         0.0134         1.4936      200000001
                           logistic_orig     0.0000e+00     0.0000e+00         0.0199         ------      200000001

Discussion of Approximating Functions

Let's talk about the fastest few approximations.

logistic_with_abs

This is the fastest function, but also one of the least accurate. It is given by $$f(x)=\frac{1}{2}\left(1+\frac{x}{1+|x|}\right)$$ It is 1.5x faster than the exact logistic with an RMS error of $6\cdot10^{-2}$ in the range $[-10,10]$.
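To get a feel for how large the error is, here is a minimal sketch that compares the formula above against the exact logistic on a grid. The helper max_abs_error and the name logistic_exact are hypothetical (they are not part of the benchmark code), and the grid spacing is an arbitrary choice.

```cpp
#include <algorithm>
#include <cmath>

// Exact logistic, used only as the reference.
inline double logistic_exact(double x) {
  return 1.0 / (1.0 + std::exp(-x));
}

// The approximation from the formula above: no exp, just abs, add, divide.
inline double logistic_with_abs(double x) {
  return 0.5 * (1.0 + x / (1.0 + std::abs(x)));
}

// Hypothetical helper: largest absolute error on [lo, hi], sampled every `step`.
inline double max_abs_error(double lo, double hi, double step) {
  double worst = 0.0;
  for (double x = lo; x <= hi; x += step) {
    worst = std::max(worst, std::abs(logistic_exact(x) - logistic_with_abs(x)));
  }
  return worst;
}
```

Calling max_abs_error(-10, 10, 0.01) should land close to the maxdiff column of the table above. The two curves agree exactly at x = 0 and drift apart in the tails, because x/(1+|x|) approaches its asymptotes much more slowly than the logistic does.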

The Schraudolph Approximation

This approximation is drawn from this paper and relies on dark magic involving the IEEE-754 definitions of floating-point numbers. From the paper:

After multiplication, the fractional part of y will spill over into the highest-order bits of the mantissa $m$. This spillover is not only harmless, but in fact is highly desirable—under the IEEE-754 format, it amounts to a linear interpolation between neighboring integer exponents. The technique therefore exponentiates real-valued arguments as well as a lookup table with $2^{11}$ entries and linear interpolation.

[Figure: graph of the Schraudolph approximation at different scales]

This function has a relatively low RMS error of $1.6\cdot10^{-3}$ and a 1.4x speed-up versus the exact logistic.

log_w_approx_exp_no_overflow16

This function relies on the approximation $$f(x)=\frac{1}{1+\left(1-\frac{x}{n}\right)^{n}}$$ where we've used $n=16$. The advantage of $n=16$ is that it compiles to very simple assembly code, giving a 1.4x speed-up with a decent RMS error of $7.0\cdot10^{-3}$. Increasing $n$ would give a better approximation at the cost of worse performance.
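For context, the approximation follows from the limit definition of the exponential. A short derivation, using only standard identities:

$$e^{-x}=\lim_{n\to\infty}\left(1-\frac{x}{n}\right)^{n}\quad\Longrightarrow\quad\frac{1}{1+e^{-x}}\approx\frac{1}{1+\left(1-\frac{x}{n}\right)^{n}}\ \text{ for finite } n$$

Expanding the logarithm shows where the error comes from:

$$\left(1-\frac{x}{n}\right)^{n}=\exp\left(n\ln\left(1-\frac{x}{n}\right)\right)=\exp\left(-x-\frac{x^{2}}{2n}-\dots\right)\quad\text{for } |x|<n$$

so the leading relative error in the exponential is about $e^{-x^{2}/(2n)}$. It is small when $x^{2}\ll n$ and shrinks as $n$ grows.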

Passing the code

double exp_product_approx16(double x){
  x = 1 + x / 16;
  const auto a = x * x;
  const auto b = a * a;
  const auto c = b * b;
  const auto d = c * c;
  return d;
}
inline double log_w_approx_exp_no_overflow16(double x){
  return 1/(1+exp_product_approx16(-x));
}

through Godbolt we get

.LCPI0_0:
        .quad   0x3fb0000000000000              # double 0.0625
.LCPI0_1:
        .quad   0x3ff0000000000000              # double 1
exp_product_approx16(double):              # @exp_product_approx16(double)
        mulsd   xmm0, qword ptr [rip + .LCPI0_0]
        addsd   xmm0, qword ptr [rip + .LCPI0_1]
        mulsd   xmm0, xmm0
        mulsd   xmm0, xmm0
        mulsd   xmm0, xmm0
        mulsd   xmm0, xmm0
        ret

where we see that each doubling of $n$ introduces only a single additional mulsd instruction.

log_w_approx_exp_no_overflow16_clamped

This function appears to have nearly identical characteristics to log_w_approx_exp_no_overflow16, but is defined as $$f(x)=\begin{cases} 1 & x\ge n \\ \frac{1}{1+\left(1-\frac{x}{n}\right)^{n}} & \textrm{otherwise} \end{cases}$$ If you squint at the unclamped function, you find that it increases on $(-\infty,n)$ and then decreases on $(n,\infty)$:

[Figure: plot of a power approximation of the logistic function]

This is fine if our inputs are all in the range $(-\infty,n]$, but if the inputs are larger than this it becomes problematic. We can handle this either by increasing $n$ (remember that doubling it only adds a single assembly instruction) or by introducing an unlikely if. We choose to use the if in case the input is adversarial and do so with very minor performance costs:

constexpr double exp_product_approx16(double x){
  x = 1 + x / 16;
  const auto a = x * x;
  const auto b = a * a;
  const auto c = b * b;
  const auto d = c * c;
  return d;
}
inline double log_w_approx_exp_no_overflow16_clamped(double x){
  if(x >= 16) [[unlikely]] {
      return 1;
  }
  return 1/(1+exp_product_approx16(-x));
}
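A small sketch of why the clamp matters. The names unclamped16 and clamped16 are hypothetical stand-ins for the two functions above, and I have left out the [[unlikely]] attribute so the snippet compiles before C++20.

```cpp
// Same repeated-squaring helper as in the post: (1 + x/16)^16.
constexpr double exp_product_approx16(double x) {
  x = 1 + x / 16;
  const double a = x * x;
  const double b = a * a;
  const double c = b * b;
  return c * c;
}

// Without the clamp: the even power makes the curve turn back down past x = 16.
constexpr double unclamped16(double x) {
  return 1 / (1 + exp_product_approx16(-x));
}

// With the clamp: saturate at 1 once x reaches n = 16.
constexpr double clamped16(double x) {
  if (x >= 16) return 1;
  return 1 / (1 + exp_product_approx16(-x));
}
```

At x = 16 the unclamped version returns exactly 1, because 1 - 16/16 = 0. At x = 40 the inner term is (-1.5)^16, roughly 657, so unclamped16(40) comes out near 0.0015 where the true logistic is essentially 1. The clamped version returns 1 there and matches the unclamped one everywhere below 16.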

Note that the way this code is written is important! If we were to write $x^{16}$ as:

x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x

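the compiler would generate 15 multiplication instructions, one per *, because floating-point multiplication is not associative and the compiler is not allowed to regroup the product on its own. The -ffast-math flag permits that regrouping, but at the cost of potentially reducing the accuracy of math throughout your program.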

Similarly, using

std::pow(x, 16)

will generate a call to the std::pow function, which will be slower than the doubling method we've used above.

Recommendation

I recommend you benchmark the approximations on your own system and choose a function with the quality/performance trade-off that works best for you. Personally, I've found log_w_approx_exp_no_overflow16_clamped to be sufficient for my needs and prefer it to the black magic of the Schraudolph approximation. This is especially so since the accuracy of the Schraudolph approximation is fixed while adjusting $n$ allows me to easily tune my accuracy/performance trade-off (see, for example, log_w_approx_exp_no_overflow128).

Benchmarking Code

// Compile with: clang++ -O3 -std=c++20 test.cpp  ([[unlikely]] is a C++20 attribute)
// Functions plotted here: https://www.desmos.com/calculator/nkblxiypxh
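#include <algorithm>  // std::max
#include <chrono>
#include <cmath>
#include <cstdint>    // int64_t
#include <endian.h>
#include <iomanip>
#include <iostream>
#include <limits>     // std::numeric_limits
#include <string>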
constexpr double STEP_SIZE = 0.0000001;
// constexpr double STEP_SIZE = 0.001;
constexpr int NAME_LEN = 40;
constexpr double logistic_orig(double x) {
  return x > 0 ? (1.0 / (1.0 + std::exp(-x))) : (1.0 - 1.0 / (1.0 + std::exp(x)));
}
// Based on "A Fast, Compact Approximation of the Exponential Function"
// By Nicol N. Schraudolph
// https://nic.schraudolph.org/pubs/Schraudolph99.pdf
constexpr inline double exp_schraudolph_approx(double x){
  constexpr auto c2_to_the_20th = 1 << 20;
  constexpr auto ln2 = 0.6931471805599453;
  constexpr auto a = c2_to_the_20th / ln2;
  constexpr auto b = 1023 * c2_to_the_20th;
  constexpr auto c = 60801;
  union {
    double d;
// NOTE: The order of the two ints depends on byte order; i must overlay
// the high 32 bits of d (sign, exponent, top of the mantissa).
#if __BYTE_ORDER == __LITTLE_ENDIAN
    struct {
      int j, i;
    } n;
#else
    struct {
      int i, j;
    } n;
#endif
  } eco = {};
  // Black magic happens here. Read the paper.
  eco.n.i = a * x + (b - c);
  return eco.d;
}
constexpr double logistic_schraudolph_approx(double x) {
  return x > 0 ? (1.0 / (1.0 + exp_schraudolph_approx(-x)))
               : (1.0 - 1.0 / (1.0 + exp_schraudolph_approx(x)));
}
constexpr double exp_product_approx128(double x){
  x = 1 + x / 128;
  const auto a = x * x;
  const auto b = a * a;
  const auto c = b * b;
  const auto d = c * c;
  const auto e = d * d;
  const auto f = e * e;
  const auto g = f * f;
  return g;
}
constexpr double exp_product_approx16(double x){
  x = 1 + x / 16;
  const auto a = x * x;
  const auto b = a * a;
  const auto c = b * b;
  const auto d = c * c;
  return d;
}
constexpr double logistic_product_approx128(double x) {
  return x > 0 ? (1.0 / (1.0 + exp_product_approx128(-x)))
               : (1.0 - 1.0 / (1.0 + exp_product_approx128(x)));
}
constexpr float exp_product_approx_float128(float x){
  x = 1 + x / 128;
  const auto a = x * x;
  const auto b = a * a;
  const auto c = b * b;
  const auto d = c * c;
  const auto e = d * d;
  const auto f = e * e;
  const auto g = f * f;
  return g;
}
constexpr float logistic_product_approx_float128(float x) {
  return x > 0 ? (1.0 / (1.0 + exp_product_approx_float128(-x)))
               : (1.0 - 1.0 / (1.0 + exp_product_approx_float128(x)));
}
inline double fexp_quintic(double x){
  constexpr int64_t mantissa = static_cast<int64_t>(1)<<52;
  constexpr int64_t bias = 1023;
  constexpr int64_t ishift = mantissa*bias;
  constexpr double ln2 = 0.6931471805599453;
  constexpr double s1 = -1.90188191959304e-3;
  constexpr double s2 = -9.01146535969578e-3;
  constexpr double s3 = -5.57129652016652e-2;
  constexpr double s4 = -2.40226506959101e-1;
  constexpr double s5 =  3.06852819440055e-1;
  const double y = x/ln2;
  const double yf = y-std::floor(y);
  const double y2 = y-((((s1*yf+s2)*yf+s3)*yf+s4)*yf+s5)*yf;
  const int64_t i8 = mantissa*y2+ishift;
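  // NOTE: Type punning through reinterpret_cast violates strict aliasing and is
  // technically undefined behaviour; std::memcpy is the portable alternative.
  return *reinterpret_cast<const double*>(&i8);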
}
inline double logistic_fexp_quintic_approx(double x) {
  return x > 0 ? (1.0 / (1.0 + fexp_quintic(-x)))
               : (1.0 - 1.0 / (1.0 + fexp_quintic(x)));
}
inline double logistic_with_abs(double x){
  return 0.5*(1+x/(1+std::abs(x)));
}
inline double logistic_with_tanh(double x){
  return 0.5*(1+std::tanh(x));
}
inline double logistic_with_erf(double x){
  return 0.5*(1+std::erf(std::sqrt(M_PI)*x/2));
}
inline double logistic_with_sqrt(double x){
  const auto temp = 0.5/std::sqrt(1+x*x);
  return (x<0)?temp:1-temp;
}
inline double logistic_with_atan(double x){
  return 0.5*(1+std::atan(M_PI*x/2)*2/M_PI);
}
inline double logistic_with_exp_no_overflow(double x){
  return 1/(1+std::exp(-x));
}
inline double log_w_approx_exp_no_overflow128(double x){
  return 1/(1+exp_product_approx128(-x));
}
inline double log_w_approx_exp_no_overflow16(double x){
  return 1/(1+exp_product_approx16(-x));
}
inline double log_w_approx_exp_no_overflow16_clamped(double x){
  // exp_product_approx16(-x) falls to 0 at x=16 and grows again beyond it,
  // so the unclamped logistic reaches y=1 at x=16 and decreases thereafter.
  // We therefore use 16 as the clamp value here.
  if(x >= 16) [[unlikely]] {
      return 1;
  }
  return 1/(1+exp_product_approx16(-x));
}
template<typename Func>
double time_it(Func func, const std::string& func_name, const double original_time){
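  // NOTE: The timed region below also evaluates logistic_orig and accumulates
  // the error, so time_us is not the cost of func alone.
  const auto start = std::chrono::high_resolution_clock::now();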
  double diff = 0;
  int count = 0;
  double maxdiff = -std::numeric_limits<double>::infinity();
  for(double x=-10;x<10;x+=STEP_SIZE){
      const auto origval = logistic_orig(x);
      const auto newval = func(x);
      diff += std::pow(origval - newval, 2);
      maxdiff = std::max(maxdiff, std::abs(origval - newval));
      count++;
  }
  const auto end = std::chrono::high_resolution_clock::now();
  const auto time_per_sample = std::chrono::duration_cast<std::chrono::microseconds>(end-start).count()/static_cast<double>(count);
  std::cout<<std::setw(NAME_LEN)<<func_name
           <<std::scientific<<std::setprecision(4)<<std::setw(15)<<std::sqrt(diff/count)
           <<std::scientific<<std::setprecision(4)<<std::setw(15)<<maxdiff
           <<std::fixed<<std::setw(15)<<time_per_sample
           <<std::fixed<<std::setw(15)<<original_time/time_per_sample
           <<std::setw(15)<<count
           <<std::endl;
  return time_per_sample;
}
int main(){
  std::cout<<std::setw(NAME_LEN)<<"name"
           <<std::setw(15)<<"rms_error"
           <<std::setw(15)<<"maxdiff"
           <<std::setw(15)<<"time_us"
           <<std::setw(15)<<"speedup"
           <<std::setw(15)<<"samples"
           <<std::endl;
  double original_speed = 1;
  original_speed = time_it(logistic_orig, "logistic_orig", original_speed);
  time_it(logistic_schraudolph_approx, "logistic_schraudolph_approx", original_speed);
  time_it(logistic_product_approx128, "logistic_product_approx128", original_speed);
  time_it(logistic_fexp_quintic_approx, "logistic_fexp_quintic_approx", original_speed);
  time_it(logistic_product_approx_float128, "logistic_product_approx_float128", original_speed);
  time_it(logistic_with_abs, "logistic_with_abs", original_speed);
  time_it(logistic_with_tanh, "logistic_with_tanh", original_speed);
  time_it(logistic_with_erf, "logistic_with_erf", original_speed);
  time_it(logistic_with_sqrt, "logistic_with_sqrt", original_speed);
  time_it(logistic_with_atan, "logistic_with_atan", original_speed);
  time_it(logistic_with_exp_no_overflow, "logistic_with_exp_no_overflow", original_speed);
  time_it(log_w_approx_exp_no_overflow128, "log_w_approx_exp_no_overflow128", original_speed);
  time_it(log_w_approx_exp_no_overflow16, "log_w_approx_exp_no_overflow16", original_speed);
  time_it(log_w_approx_exp_no_overflow16_clamped, "log_w_approx_exp_no_overflow16_clamped", original_speed);
  return 0;
}