Код Numba медленнее, чем чистый python
Я работал над ускорением расчета пересчета для фильтра частиц. Поскольку у python есть много способов ускорить его, я бы попробовал их все. К сожалению, версия numba невероятно медленная. Поскольку Numba должна привести к ускорению, я предполагаю, что это ошибка с моей стороны.
Я пробовал 4 разные версии:
- Numba
- Python
- включает в себя
- на Cython
код для каждого ниже:
import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample
@nb.autojit
def numba_resample(qs, xs, rands):
n = qs.shape[0]
lookup = np.cumsum(qs)
results = np.empty(n)
for j in range(n):
for i in range(n):
if rands[j] < lookup[i]:
results[j] = xs[i]
break
return results
def python_resample(qs, xs, rands):
n = qs.shape[0]
lookup = np.cumsum(qs)
results = np.empty(n)
for j in range(n):
for i in range(n):
if rands[j] < lookup[i]:
results[j] = xs[i]
break
return results
def numpy_resample(qs, xs, rands):
results = np.empty_like(qs)
lookup = sp.cumsum(qs)
for j, key in enumerate(rands):
i = sp.argmax(lookup>key)
results[j] = xs[i]
return results
#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs,
np.ndarray[DTYPE_t, ndim=1] xs,
np.ndarray[DTYPE_t, ndim=1] rands):
if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
raise ValueError("Arrays must have same shape")
assert qs.dtype == xs.dtype == rands.dtype == DTYPE
cdef unsigned int n = qs.shape[0]
cdef unsigned int i, j
cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)
for j in range(n):
for i in range(n):
if rands[j] < lookup[i]:
results[j] = xs[i]
break
return results
"""
if __name__ == '__main__':
n = 100
xs = np.arange(n, dtype=np.float64)
qs = np.array([1.0/n,]*n)
rands = np.random.rand(n)
print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)
print "Timing Python Function:"
%timeit python_resample(qs, xs, rands)
print "Timing Numpy Function:"
%timeit numpy_resample(qs, xs, rands)
print "Timing Cython Function:"
%timeit cython_resample(qs, xs, rands)
это приводит к следующему выводу:
Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop
есть идеи, почему код numba настолько медленный? я предположил, что это будет по крайней мере сопоставимо с Numpy.
Примечание: Если у кого-нибудь есть идеи о том, как ускорить образцы кода Numpy или Cython, это тоже было бы неплохо:) мой главный вопрос касается Numba.
2 ответов
проблема в том, что numba не может интуитивно определить тип lookup
. Если вы ставите print nb.typeof(lookup)
в вашем методе вы увидите, что numba рассматривает его как объект, который является медленным. Обычно я просто определяю тип lookup
в местном Дикте, но я получал странную ошибку. Вместо этого я просто создал небольшую оболочку, чтобы я мог явно определить типы ввода и вывода.
@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
return np.cumsum(x)
@nb.autojit
def numba_resample2(qs, xs, rands):
n = qs.shape[0]
#lookup = np.cumsum(qs)
lookup = numba_cumsum(qs)
results = np.empty(n)
for j in range(n):
for i in range(n):
if rands[j] < lookup[i]:
results[j] = xs[i]
break
return results
тогда мои тайминги:
print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)
print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)
Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop
вы можете пойти даже немного быстрее, если вы используете jit
вместо autojit
:
@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))
для меня это снижает его с 15,3 микросекунд до 12,5 микросекунд, но все равно впечатляет, насколько хорошо работает autojit.
быстрее numpy
версия (10x ускорение по сравнению с numpy_resample
)
def numpy_faster(qs, xs, rands):
lookup = np.cumsum(qs)
mm = lookup[None,:]>rands[:,None]
I = np.argmax(mm,1)
return xs[I]