Faster Logistic Function
An answer to this question on the Scientific Computing Stack Exchange.
Question
I've noticed that a fairly significant number of cycles in one of my programs are being consumed by the logistic function: $$f(x)=\frac{1}{1+e^{-x}}$$ Is there a good approximation I can use to reduce the cost of this function?
Answer
Yes! There are nice approximations of the logistic.
Plot of Approximating Functions
Several functions approximate the logistic, which is shown as blue dots in the accompanying plot. The plot is available interactively on Desmos: https://www.desmos.com/calculator/nkblxiypxh

Benchmark Results of Approximating Functions
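To explore this further, I've benchmarked all of the functions using the code at the bottom of this post. Errors are measured against logistic_orig over the range $[-10,10)$, time_us is the mean time per sample in microseconds, and speedup is relative to logistic_orig. The results are as follows: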
name                                    rms_error   maxdiff     time_us  speedup  samples
logistic_with_tanh                      5.9496e-02  1.5014e-01  0.0393   0.5076   200000001
logistic_with_atan                      3.9051e-02  9.6934e-02  0.0321   0.6211   200000001
logistic_with_erf                       6.5068e-02  1.6581e-01  0.0299   0.6676   200000001
logistic_fexp_quintic_approx            1.2921e-07  5.9050e-07  0.0246   0.8118   200000001
logistic_product_approx_float128        8.7032e-04  1.7217e-03  0.0209   0.9523   200000001
logistic_with_exp_no_overflow           4.7660e-17  1.6653e-16  0.0198   1.0084   200000001
logistic_product_approx128              8.7032e-04  1.7211e-03  0.0164   1.2187   200000001
log_w_approx_exp_no_overflow128         8.7193e-04  1.7211e-03  0.0158   1.2640   200000001
logistic_with_sqrt                      8.3414e-02  1.1086e-01  0.0146   1.3662   200000001
log_w_approx_exp_no_overflow16          6.9726e-03  1.4074e-02  0.0141   1.4114   200000001
log_w_approx_exp_no_overflow16_clamped  6.9726e-03  1.4074e-02  0.0141   1.4153   200000001
logistic_schraudolph_approx             1.5661e-03  8.9906e-03  0.0138   1.4497   200000001
logistic_with_abs                       6.0968e-02  8.2289e-02  0.0134   1.4936   200000001
logistic_orig                           0.0000e+00  0.0000e+00  0.0199   ------   200000001
Discussion of Approximating Functions
Let's talk about the fastest few approximations.
logistic_with_abs
This is the fastest function, but also among the least accurate, and is given by $$f(x)=\frac{1}{2}\left(1+\frac{x}{1+|x|}\right)$$ It is 1.5x faster than the exact logistic with an RMS error of $6\cdot10^{-2}$ in the range $[-10,10]$.
The Schraudolph Approximation
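This approximation is drawn from Nicol N. Schraudolph's paper "A Fast, Compact Approximation of the Exponential Function" (https://nic.schraudolph.org/pubs/Schraudolph99.pdf) and relies on dark magic involving the IEEE-754 definitions of floating-point numbers. From the paper: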
After multiplication, the fractional part of y will spill over into the highest-order bits of the mantissa $m$. This spillover is not only harmless, but in fact is highly desirable—under the IEEE-754 format, it amounts to a linear interpolation between neighboring integer exponents. The technique therefore exponentiates real-valued arguments as well as a lookup table with $2^{11}$ entries and linear interpolation.
This function has a relatively low RMS error of $1.6\cdot10^{-3}$ and a 1.4x speed-up versus the exact logistic.
log_w_approx_exp_no_overflow16
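This function relies on the approximation $e^{x}\approx\left(1+\frac{x}{n}\right)^{n}$, which becomes exact as $n\to\infty$. Substituting it for $e^{-x}$ in the logistic gives $$f(x)=\frac{1}{1+\left(1-\frac{x}{n}\right)^{n}}$$ where we've used $n=16$. The advantage of a power of two such as $n=16$ is that the power can be computed by repeated squaring, which boils down to very simple assembly code, giving a 1.4x speed-up with a decent RMS error of $7.0\cdot10^{-3}$. Increasing $n$ would give a better approximation at the cost of worse performance.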
Passing the code
double exp_product_approx16(double x){
x = 1 + x / 16;
const auto a = x * x;
const auto b = a * a;
const auto c = b * b;
const auto d = c * c;
return d;
}
inline double log_w_approx_exp_no_overflow16(double x){
return 1/(1+exp_product_approx16(-x));
}
through Godbolt we get
.LCPI0_0:
.quad 0x3fb0000000000000 # double 0.0625
.LCPI0_1:
.quad 0x3ff0000000000000 # double 1
exp_product_approx16(double): # @exp_product_approx16(double)
mulsd xmm0, qword ptr [rip + .LCPI0_0]
addsd xmm0, qword ptr [rip + .LCPI0_1]
mulsd xmm0, xmm0
mulsd xmm0, xmm0
mulsd xmm0, xmm0
mulsd xmm0, xmm0
ret
where we see that each doubling of $n$ introduces only a single additional mulsd instruction.
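Since $n=2^{k}$ costs $k$ squarings, the pattern can be written once and reused for any power of two. The following is a sketch rather than part of the benchmarked code; the names exp_product_approx_pow2 and logistic_pow2 are hypothetical, and whether the loop is fully unrolled into consecutive mulsd instructions depends on the compiler and flags, so it is worth checking on Godbolt.

```cpp
// Hypothetical generalization: approximates e^x by (1 + x/2^K)^(2^K)
// using K squarings. K = 4 corresponds to n = 16, K = 7 to n = 128.
template <int K>
constexpr double exp_product_approx_pow2(double x) {
  static_assert(K >= 0 && K < 31, "K must be a small non-negative integer");
  x = 1 + x / static_cast<double>(1 << K);
  for (int i = 0; i < K; ++i) {
    x *= x;  // each squaring doubles the exponent
  }
  return x;
}

// Unclamped logistic built on the approximation above.
template <int K>
constexpr double logistic_pow2(double x) {
  return 1 / (1 + exp_product_approx_pow2<K>(-x));
}
```

With these definitions, logistic_pow2<4>(x) performs the same arithmetic as log_w_approx_exp_no_overflow16(x), and moving to logistic_pow2<7>(x) trades three extra multiplications for the accuracy of the $n=128$ variant.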
log_w_approx_exp_no_overflow16_clamped
This function appears to have nearly identical characteristics to log_w_approx_exp_no_overflow16, but is defined as
$$f(x)=\begin{cases}
1 & x\ge n \\
\frac{1}{1+\left(1-\frac{x}{n}\right)^{n}} & \textrm{otherwise}
\end{cases}$$
If you squint at the unclamped function, you find that it increases on $(-\infty,n)$, reaches 1 at $x=n$, and then decreases on $(n,\infty)$.
This is fine if our inputs are all in the range $(-\infty,n]$, but if the inputs are larger than this it becomes problematic. We can handle this either by increasing $n$ (remember that doubling it only adds a single assembly instruction) or by introducing an unlikely if. We choose to use the if in case the input is adversarial and do so with very minor performance costs:
constexpr double exp_product_approx16(double x){
x = 1 + x / 16;
const auto a = x * x;
const auto b = a * a;
const auto c = b * b;
const auto d = c * c;
return d;
}
inline double log_w_approx_exp_no_overflow16_clamped(double x){
if(x >= 16) [[unlikely]] {
return 1;
}
return 1/(1+exp_product_approx16(-x));
}
Note that the way this code is written is important! If we were to write $x^{16}$ as:
x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x
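the compiler would generate 15 sequential multiplication instructions, because floating-point multiplication is not associative and the compiler is not allowed to regroup the product! The -ffast-math flag permits the regrouping, but at the cost of potentially reducing the accuracy of math throughout your program.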
Similarly, using
std::pow(x, 16)
will generate a call to the std::pow function, which will be slower than the doubling method we've used above.
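Returning to the clamp, its effect is easy to check by evaluating the approximation past $x=16$. The sketch below repeats the squaring routine so that it is self-contained; the names unclamped16 and clamped16 are hypothetical stand-ins for log_w_approx_exp_no_overflow16 and log_w_approx_exp_no_overflow16_clamped, and the [[unlikely]] hint is omitted so that it also compiles under older language standards.

```cpp
// (1 + x/16)^16 via four squarings, as in the listing above.
constexpr double exp_product_approx16(double x) {
  x = 1 + x / 16;
  const auto a = x * x;
  const auto b = a * a;
  const auto c = b * b;
  const auto d = c * c;
  return d;
}

// Hypothetical name: the approximation without the clamp.
inline double unclamped16(double x) {
  return 1 / (1 + exp_product_approx16(-x));
}

// Hypothetical name: the approximation with the clamp at x = 16.
inline double clamped16(double x) {
  if (x >= 16) {
    return 1;
  }
  return 1 / (1 + exp_product_approx16(-x));
}
```

Here unclamped16(16.0) is exactly 1, but unclamped16(32.0) has already fallen back to 0.5 and unclamped16(48.0) is 1/65537, whereas clamped16 returns 1 for all three inputs.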
Recommendation
I recommend you benchmark the approximations on your own system and choose a function with the quality/performance trade-off that works best for you. Personally, I've found log_w_approx_exp_no_overflow16_clamped to be sufficient for my needs and prefer it to the black magic of the Schraudolph approximation. This is especially so since the accuracy of the Schraudolph approximation is fixed while adjusting $n$ allows me to easily tune my accuracy/performance trade-off (see, for example, log_w_approx_exp_no_overflow128).
Benchmarking Code
// Compile with: clang++ -O3 test.cpp
// Functions plotted here: https://www.desmos.com/calculator/nkblxiypxh
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <endian.h>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
constexpr double STEP_SIZE = 0.0000001;
// constexpr double STEP_SIZE = 0.001;
constexpr int NAME_LEN = 40;
constexpr double logistic_orig(double x) {
return x > 0 ? (1.0 / (1.0 + std::exp(-x))) : (1.0 - 1.0 / (1.0 + std::exp(x)));
}
// Based on "A Fast, Compact Approximation of the Exponential Function"
// By Nicol N. Schraudolph
// https://nic.schraudolph.org/pubs/Schraudolph99.pdf
constexpr inline double exp_schraudolph_approx(double x){
constexpr auto c2_to_the_20th = 1 << 20;
constexpr auto ln2 = 0.6931471805599453;
constexpr auto a = c2_to_the_20th / ln2;
constexpr auto b = 1023 * c2_to_the_20th;
constexpr auto c = 60801;
union {
double d;
// NOTE: This works for a little-endian architecture
#if __BYTE_ORDER == __LITTLE_ENDIAN
struct {
int j, i;
} n;
#else
struct {
int i, j;
} n;
#endif
} eco = {};
// Black magic happens here. Read the paper.
eco.n.i = a * x + (b - c);
return eco.d;
}
constexpr double logistic_schraudolph_approx(double x) {
return x > 0 ? (1.0 / (1.0 + exp_schraudolph_approx(-x)))
: (1.0 - 1.0 / (1.0 + exp_schraudolph_approx(x)));
}
constexpr double exp_product_approx128(double x){
x = 1 + x / 128;
const auto a = x * x;
const auto b = a * a;
const auto c = b * b;
const auto d = c * c;
const auto e = d * d;
const auto f = e * e;
const auto g = f * f;
return g;
}
constexpr double exp_product_approx16(double x){
x = 1 + x / 16;
const auto a = x * x;
const auto b = a * a;
const auto c = b * b;
const auto d = c * c;
return d;
}
constexpr double logistic_product_approx128(double x) {
return x > 0 ? (1.0 / (1.0 + exp_product_approx128(-x)))
: (1.0 - 1.0 / (1.0 + exp_product_approx128(x)));
}
constexpr float exp_product_approx_float128(float x){
x = 1 + x / 128;
const auto a = x * x;
const auto b = a * a;
const auto c = b * b;
const auto d = c * c;
const auto e = d * d;
const auto f = e * e;
const auto g = f * f;
return g;
}
constexpr float logistic_product_approx_float128(float x) {
return x > 0 ? (1.0 / (1.0 + exp_product_approx_float128(-x)))
: (1.0 - 1.0 / (1.0 + exp_product_approx_float128(x)));
}
inline double fexp_quintic(double x){
constexpr int64_t mantissa = static_cast<int64_t>(1)<<52;
constexpr int64_t bias = 1023;
constexpr int64_t ishift = mantissa*bias;
constexpr double ln2 = 0.6931471805599453;
constexpr double s1 = -1.90188191959304e-3;
constexpr double s2 = -9.01146535969578e-3;
constexpr double s3 = -5.57129652016652e-2;
constexpr double s4 = -2.40226506959101e-1;
constexpr double s5 = 3.06852819440055e-1;
const double y = x/ln2;
const double yf = y-std::floor(y);
const double y2 = y-((((s1*yf+s2)*yf+s3)*yf+s4)*yf+s5)*yf;
const int64_t i8 = mantissa*y2+ishift;
return *reinterpret_cast<const double*>(&i8);
}
inline double logistic_fexp_quintic_approx(double x) {
return x > 0 ? (1.0 / (1.0 + fexp_quintic(-x)))
: (1.0 - 1.0 / (1.0 + fexp_quintic(x)));
}
inline double logistic_with_abs(double x){
return 0.5*(1+x/(1+std::abs(x)));
}
inline double logistic_with_tanh(double x){
return 0.5*(1+std::tanh(x));
}
inline double logistic_with_erf(double x){
return 0.5*(1+std::erf(std::sqrt(M_PI)*x/2));
}
inline double logistic_with_sqrt(double x){
const auto temp = 0.5/std::sqrt(1+x*x);
return (x<0)?temp:1-temp;
}
inline double logistic_with_atan(double x){
return 0.5*(1+std::atan(M_PI*x/2)*2/M_PI);
}
inline double logistic_with_exp_no_overflow(double x){
return 1/(1+std::exp(-x));
}
inline double log_w_approx_exp_no_overflow128(double x){
return 1/(1+exp_product_approx128(-x));
}
inline double log_w_approx_exp_no_overflow16(double x){
return 1/(1+exp_product_approx16(-x));
}
inline double log_w_approx_exp_no_overflow16_clamped(double x){
// exp_product_approx16(-x) falls to 0 at x=16 and grows again beyond
// that, so the approximate logistic reaches y=1 at x=16 and must
// decrease thereafter. We therefore use 16 as the clamp value here.
if(x >= 16) [[unlikely]] {
return 1;
}
return 1/(1+exp_product_approx16(-x));
}
template<typename Func>
double time_it(Func func, const std::string& func_name, const double original_time){
const auto start = std::chrono::high_resolution_clock::now();
double diff = 0;
int count = 0;
double maxdiff = -std::numeric_limits<double>::infinity();
for(double x=-10;x<10;x+=STEP_SIZE){
const auto origval = logistic_orig(x);
const auto newval = func(x);
diff += std::pow(origval - newval, 2);
maxdiff = std::max(maxdiff, std::abs(origval - newval));
count++;
}
const auto end = std::chrono::high_resolution_clock::now();
const auto time_per_sample = std::chrono::duration_cast<std::chrono::microseconds>(end-start).count()/static_cast<double>(count);
std::cout<<std::setw(NAME_LEN)<<func_name
<<std::scientific<<std::setprecision(4)<<std::setw(15)<<std::sqrt(diff/count)
<<std::scientific<<std::setprecision(4)<<std::setw(15)<<maxdiff
<<std::fixed<<std::setw(15)<<time_per_sample
<<std::fixed<<std::setw(15)<<original_time/time_per_sample
<<std::setw(15)<<count
<<std::endl;
return time_per_sample;
}
int main(){
std::cout<<std::setw(NAME_LEN)<<"name"
<<std::setw(15)<<"rms_error"
<<std::setw(15)<<"maxdiff"
<<std::setw(15)<<"time_us"
<<std::setw(15)<<"speedup"
<<std::setw(15)<<"samples"
<<std::endl;
double original_speed = 1;
original_speed = time_it(logistic_orig, "logistic_orig", original_speed);
time_it(logistic_schraudolph_approx, "logistic_schraudolph_approx", original_speed);
time_it(logistic_product_approx128, "logistic_product_approx128", original_speed);
time_it(logistic_fexp_quintic_approx, "logistic_fexp_quintic_approx", original_speed);
time_it(logistic_product_approx_float128, "logistic_product_approx_float128", original_speed);
time_it(logistic_with_abs, "logistic_with_abs", original_speed);
time_it(logistic_with_tanh, "logistic_with_tanh", original_speed);
time_it(logistic_with_erf, "logistic_with_erf", original_speed);
time_it(logistic_with_sqrt, "logistic_with_sqrt", original_speed);
time_it(logistic_with_atan, "logistic_with_atan", original_speed);
time_it(logistic_with_exp_no_overflow, "logistic_with_exp_no_overflow", original_speed);
time_it(log_w_approx_exp_no_overflow128, "log_w_approx_exp_no_overflow128", original_speed);
time_it(log_w_approx_exp_no_overflow16, "log_w_approx_exp_no_overflow16", original_speed);
time_it(log_w_approx_exp_no_overflow16_clamped, "log_w_approx_exp_no_overflow16_clamped", original_speed);
return 0;
}
