Самая быстрая реализация экспоненциальной функции с помощью SSE

Я ищу аппроксимацию экспоненциальной функции, работающей на элементе SSE. А именно - __m128 exp( __m128 x ).

у меня есть реализация, которая быстра, но кажется очень низкой точностью:

static inline __m128 FastExpSse(__m128 x)
{
    __m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
    __m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
    __m128  m87 = _mm_set1_ps(-87);
    // fast exponential function, x should be in [-87, 87]
    __m128 mask = _mm_cmpge_ps(x, m87);

    __m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
    return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}

может ли кто-нибудь иметь реализацию с лучшей точностью, но так же быстро (или быстрее)?

Я был бы счастлив, если бы я написал в стиле Си.

Спасибо.

3 ответов


код C ниже является переводом в SSE внутреннеприсущие алгоритма, который я использовал в предыдущий ответ к аналогичному вопросу.

основная идея состоит в том, чтобы преобразовать вычисление стандартной экспоненциальной функции в вычисление степени 2:expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504). Мы разделились t = x * 1.44269504 в целое число i и часть f, таких, что t = i + f и 0 <= f <= 1. Теперь мы можем вычислить 2f С полиномиальной аппроксимацией, затем масштабируйте результат по 2я добавлять i в поле экспоненты результата с плавающей запятой с одной точностью.

одна проблема, которая существует с реализацией SSE, заключается в том, что мы хотим вычислить i = floorf (t), но нет быстрого способа вычислить Schraudolph в основном используется линейная аппроксимация 2f ~= 1.0 + f на f в [0,1], и его точность может быть улучшена путем добавления квадратичной срок. Самая умная часть подхода Schraudolph является вычислений 2я * 2f без явного разделения целочисленной части i = floor(x * 1.44269504) из фракции. Я не вижу способа распространить этот трюк на квадратичное приближение, но, безусловно, можно объединить the floor() расчет с Schraudolph с квадратичной аппроксимации выше:

/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 f, p, r;
    __m128i t, j;
    const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
    const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
    const __m128 c0 = _mm_set1_ps (0.3371894346f);
    const __m128 c1 = _mm_set1_ps (0.657636276f);
    const __m128 c2 = _mm_set1_ps (1.00172476f);

    t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
    j = _mm_and_si128 (t, m);            /* j = (int)(floor (x/log(2))) << 23 */
    t = _mm_sub_epi32 (t, j);
    f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

хорошее повышение точности в моем алгоритме (реализация FastExpSse в ответе выше) может быть получено за счет вычитания целого числа и деления с плавающей запятой с помощью FastExpSse(x/2)/FastExpSse (- x/2) вместо FastExpSse(x). Трюк здесь состоит в том, чтобы установить параметр сдвига (298765 выше) на ноль, чтобы кусочно-линейные аппроксимации в числителе и знаменателе выровнялись, чтобы дать вам существенную отмену ошибок. Сверни ее в одну. функция:

__m128 BetterFastExpSse (__m128 x)
{
  const __m128 a = _mm_set1_ps ((1 << 22) / float(M_LN2));  // to get exp(x/2)
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));       // NB: zero shift!
  __m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
  __m128i s = _mm_add_epi32 (b, r);
  __m128i t = _mm_sub_epi32 (b, r);
  return _mm_div_ps (_mm_castsi128_ps (s), _mm_castsi128_ps (t));
}

(Я не аппаратный парень - насколько плохим убийцей производительности является подразделение здесь?)

Если вам нужен exp(x), чтобы получить y = tanh (x) (например, для нейронных сетей), используйте FastExpSse с нулевым сдвигом следующим образом:

a = FastExpSse(x);
b = FastExpSse(-x);
y = (a - b)/(a + b);

чтобы получить тот же тип преимущества отмены ошибок. Логистическая функция работает аналогично, используя FastExpSse(x/2)/(FastExpSse (x/2) + FastExpSse (- x/2)) с нулевым сдвигом. (Это просто чтобы показать принцип - вы, очевидно, не хотите оценить FastExpSse несколько раз здесь, но сверните его в одну функцию по линиям BetterFastExpSse выше.)

Я разработал серию аппроксимаций более высокого порядка из этого, все более точных, но и медленнее. Неопубликованный, но рад сотрудничать, если кто-то хочет дать им спину.

и, наконец, для некоторого удовольствия: используйте заднюю передачу, чтобы получить FastLogSse. Цепочки с FastExpSse дает вам как оператор и отмене ошибку, и соз молниеносная степенная функция...


возвращаясь к моим заметкам с тех пор, я изучил способы повышения точности без использования деления. Я использовал тот же трюк с переинтерпретацией как поплавок, но применил полиномиальную коррекцию к мантиссе, которая была по существу вычислена в 16-битной арифметике с фиксированной точкой (единственный способ сделать это быстро).

кубический resp. квартирные версии дают вам 4 resp. 5 значащих цифр точности. Не было смысла увеличивать порядок дальше этого, так как шум затем низкоточная арифметика начинает заглушать ошибку полиномиального приближения. Вот простые версии C:

#include <stdint.h>

float fastExp3(register float x)  // cubic spline approximation
{
    union { float f; int32_t i; } reinterpreter;

    reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
    int32_t m = (reinterpreter.i >> 7) & 0xFFFF;  // copy mantissa
    // empirical values for small maximum relative error (8.34e-5):
    reinterpreter.i +=
         ((((((((1277*m) >> 14) + 14825)*m) >> 14) - 79749)*m) >> 11) - 626;
    return reinterpreter.f;
}

float fastExp4(register float x)  // quartic spline approximation
{
    union { float f; int32_t i; } reinterpreter;

    reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
    int32_t m = (reinterpreter.i >> 7) & 0xFFFF;  // copy mantissa
    // empirical values for small maximum relative error (1.21e-5):
    reinterpreter.i += (((((((((((3537*m) >> 16)
        + 13668)*m) >> 18) + 15817)*m) >> 14) - 80470)*m) >> 11);
    return reinterpreter.f;
}

квартик подчиняется (fastExp4 (0f) == 1f), что может быть важно для алгоритмов итерации с фиксированной точкой.

насколько эффективны эти целочисленные последовательности multiply-shift-add в SSE? На архитектурах, где арифметика float так же быстра, можно использовать это вместо этого, уменьшая арифметический шум. Это по существу даст cubic и квартирные расширения ответа @njuffa выше.