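It's somewhat inconvenient because it approximates tanh(0.5*ln(2)*x) rather than tanh(x). In practice you almost always scale the input anyway, so it shouldn't matter: to get tanh(y), fold the factor into that scaling and pass x = y * 2/ln(2) ≈ 2.885390082 * y.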
It's based on the identity
Code: Select all
tanh(x) = (e^(2x) - 1) / (e^(2x) + 1)
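Substituting 0.5*ln(2)*x for x turns e^(2x) into 2^x, so tanh(0.5*ln(2)*x) = (2^x - 1) / (2^x + 1). The code computes the numerator 2^x - 1 without subtractive precision loss: it never subtracts 1 from a value close to 1, so accuracy is kept for small x.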
SSE2 version:
Code: Select all
// Approximates tanh(0.5*ln(2)*x) on four packed floats (SSE2).
//
// To get tanh(y), call with x = y * (2/ln 2), i.e. x = 2.885390082f * y.
//
// Method: tanh(0.5*ln(2)*x) = (2^x - 1) / (2^x + 1). With i = round(|x|)
// and f = |x| - i, numerator and denominator are both multiplied by 2^-i:
//     (2^f - 2^-i) / (2^f + 2^-i) = (p + (1 - 2^-i)) / (p + (1 + 2^-i))
// where p ~= 2^f - 1 is a polynomial in f with no constant term. The
// numerator is thus formed without subtracting 1 from a value near 1, so
// small |x| keeps its relative accuracy. tanh is odd, so the sign of x is
// stripped first and re-applied to the numerator at the end.
//
// Special values: |x| is clamped to 30 (where tanh already rounds to +/-1
// in float), so +/-inf map to +/-1; NaN inputs produce NaN.
// Reported accuracy: ~3 ulp (~3e-7 relative).
//
// _mm_cvtps_epi32 rounds according to MXCSR; the polynomial assumes the
// default round-to-nearest mode, which keeps f within [-0.5, 0.5].
static __m128 tanh2x_mpv(__m128 x)
{
    const __m128 one = _mm_set1_ps(1.0f);
    const __m128 signmask = _mm_set1_ps(-0.0f);
    const __m128 max_val = _mm_set1_ps(30.0f);
    // p(f) = c1*f + c2*f^2 + c3*f^3 + c4*f^4 + c5*f^5 + c6*f^6 ~= 2^f - 1
    const __m128 c1 = _mm_set1_ps(6.931471825e-01f);
    const __m128 c2 = _mm_set1_ps(2.402264923e-01f);
    const __m128 c3 = _mm_set1_ps(5.550357327e-02f);
    const __m128 c4 = _mm_set1_ps(9.618237615e-03f);
    const __m128 c5 = _mm_set1_ps(1.339077600e-03f);
    const __m128 c6 = _mm_set1_ps(1.540359954e-04f);

    __m128 signs = _mm_and_ps(x, signmask);   // sign bit of each lane
    x = _mm_andnot_ps(signmask, x);           // |x|
    // Clamp |x| to 30. MINPS returns its second operand when either input
    // is NaN, so x goes second: a NaN lane stays NaN instead of becoming 30
    // (which would silently yield +/-1).
    x = _mm_min_ps(max_val, x);

    // Split |x| = i + f with i an integer (round-to-nearest).
    __m128i i = _mm_cvtps_epi32(x);
    __m128 f = _mm_sub_ps(x, _mm_cvtepi32_ps(i));

    // Evaluate p(f) as f * ((c1 + c2 f) + f^2 (c3 + c4 f) + f^4 (c5 + c6 f)),
    // pairing terms around f^2 to shorten the dependency chain without FMA.
    __m128 f2 = _mm_mul_ps(f, f);
    __m128 p = _mm_add_ps(_mm_mul_ps(c6, f), c5);
    p = _mm_mul_ps(p, f2);
    p = _mm_add_ps(p, _mm_add_ps(_mm_mul_ps(c4, f), c3));
    p = _mm_mul_ps(p, f2);
    p = _mm_add_ps(p, _mm_add_ps(_mm_mul_ps(c2, f), c1));
    p = _mm_mul_ps(p, f);

    // 2^-i, built by subtracting i from the exponent field of 1.0f. After the
    // clamp, i is in [0, 30] for non-NaN lanes, so the result stays a normal
    // float. (NaN lanes are already NaN through f and p.)
    i = _mm_slli_epi32(i, 23);
    __m128 biased_expm = _mm_castsi128_ps(_mm_sub_epi32(_mm_castps_si128(one), i));
    __m128 exp_cor = _mm_sub_ps(one, biased_expm);    // 1 - 2^-i (exact)
    __m128 exp_cor_p = _mm_add_ps(one, biased_expm);  // 1 + 2^-i

    __m128 exp2xm1 = _mm_xor_ps(signs, _mm_add_ps(p, exp_cor)); // signed (2^|x| - 1) * 2^-i
    __m128 exp2xp1 = _mm_add_ps(p, exp_cor_p);                  // (2^|x| + 1) * 2^-i
    return _mm_div_ps(exp2xm1, exp2xp1);
}
AVX2 + FMA version (requires compiling with AVX2 and FMA enabled):
Code: Select all
// Approximates tanh(0.5*ln(2)*x) on eight packed floats.
// Requires AVX2 (256-bit integer shift/subtract) and FMA3.
//
// To get tanh(y), call with x = y * (2/ln 2), i.e. x = 2.885390082f * y.
//
// Method: tanh(0.5*ln(2)*x) = (2^x - 1) / (2^x + 1). With i = round(|x|)
// and f = |x| - i, numerator and denominator are both multiplied by 2^-i:
//     (2^f - 2^-i) / (2^f + 2^-i) = (p*f + (1 - 2^-i)) / (p*f + (1 + 2^-i))
// where p*f ~= 2^f - 1 (a polynomial in f with no constant term). The
// numerator is thus formed without subtracting 1 from a value near 1, so
// small |x| keeps its relative accuracy. tanh is odd, so the sign of x is
// stripped first and re-applied to the numerator at the end.
//
// Special values: |x| is clamped to 30 (where tanh already rounds to +/-1
// in float), so +/-inf map to +/-1; NaN inputs produce NaN.
// Reported accuracy: ~3 ulp (~3e-7 relative).
//
// _mm256_cvtps_epi32 rounds according to MXCSR; the polynomial assumes the
// default round-to-nearest mode, which keeps f within [-0.5, 0.5].
static __m256 tanh2x_mpv_fma(__m256 x)
{
    const __m256 one = _mm256_set1_ps(1.0f);
    const __m256 signmask = _mm256_set1_ps(-0.0f);
    const __m256 max_val = _mm256_set1_ps(30.0f);
    // p(f)*f = c1*f + c2*f^2 + c3*f^3 + c4*f^4 + c5*f^5 + c6*f^6 ~= 2^f - 1
    const __m256 c1 = _mm256_set1_ps(6.931471825e-01f);
    const __m256 c2 = _mm256_set1_ps(2.402264923e-01f);
    const __m256 c3 = _mm256_set1_ps(5.550357327e-02f);
    const __m256 c4 = _mm256_set1_ps(9.618237615e-03f);
    const __m256 c5 = _mm256_set1_ps(1.339077600e-03f);
    const __m256 c6 = _mm256_set1_ps(1.540359954e-04f);

    __m256 signs = _mm256_and_ps(x, signmask);   // sign bit of each lane
    x = _mm256_andnot_ps(signmask, x);           // |x|
    // Clamp |x| to 30. VMINPS returns its second operand when either input
    // is NaN, so x goes second: a NaN lane stays NaN instead of becoming 30
    // (which would silently yield +/-1).
    x = _mm256_min_ps(max_val, x);

    // Split |x| = i + f with i an integer (round-to-nearest).
    __m256i i = _mm256_cvtps_epi32(x);
    __m256 f = _mm256_sub_ps(x, _mm256_cvtepi32_ps(i));

    // Horner: p = c1 + c2 f + c3 f^2 + c4 f^3 + c5 f^4 + c6 f^5.
    // The final multiply by f is fused into the numerator/denominator below.
    __m256 p = c6;
    p = _mm256_fmadd_ps(p, f, c5);
    p = _mm256_fmadd_ps(p, f, c4);
    p = _mm256_fmadd_ps(p, f, c3);
    p = _mm256_fmadd_ps(p, f, c2);
    p = _mm256_fmadd_ps(p, f, c1);

    // 2^-i, built by subtracting i from the exponent field of 1.0f. After the
    // clamp, i is in [0, 30] for non-NaN lanes, so the result stays a normal
    // float. (NaN lanes are already NaN through f and p.)
    i = _mm256_slli_epi32(i, 23);
    __m256 biased_expm = _mm256_castsi256_ps(_mm256_sub_epi32(_mm256_castps_si256(one), i));
    __m256 exp_cor = _mm256_sub_ps(one, biased_expm);    // 1 - 2^-i (exact)
    __m256 exp_cor_p = _mm256_add_ps(one, biased_expm);  // 1 + 2^-i

    __m256 exp2xm1 = _mm256_xor_ps(signs, _mm256_fmadd_ps(p, f, exp_cor)); // signed (2^|x| - 1) * 2^-i
    __m256 exp2xp1 = _mm256_fmadd_ps(p, f, exp_cor_p);                     // (2^|x| + 1) * 2^-i
    return _mm256_div_ps(exp2xm1, exp2xp1);
}
Code: Select all
MSVC vector tanh: 5.08284 c/v
tanh2x_mpv_fma: 1.40598 c/v
tanh2x_mpv: 3.79001 c/v
Accuracy:
Code: Select all
3.0507440678775311 ulps on [3 * std::numeric_limits<float>::min(), 300]
~3e-7 relative