Как векторизовать простой цикл for в Python/Numpy

Я нашел десятки примеров векторизации циклов в Python / NumPy. К сожалению, я не понимаю, как я могу уменьшить время вычисления моего простого цикла for, используя векторизованную форму. Возможно ли это в данном случае?

time = np.zeros(185000)
lat1 = np.array(([48.78,47.45],[38.56,39.53],...)) # ~ 200000 rows
lat2 = np.array(([7.78,5.45],[7.56,5.53],...)) # same number of rows as time
for ii in np.arange(len(time)):
    pos = np.argwhere( (lat1[:,0]==lat2[ii,0]) and 
                       (lat1[:,1]==lat2[ii,1]) )
    if pos.size:
        pos = int(pos)
        time[ii] = dtime[pos]

2 ответов


вероятно, самый быстрый способ найти все совпадения-отсортировать оба массива и пройти через них вместе, как в этом рабочем примере:

import numpy as np

def is_less(a, b):
    # this ugliness is needed because we want to compare lexicographically same as np.lexsort(), from the last column backward
    for i in range(len(a)-1, -1, -1):
        if a[i]<b[i]: return True
        elif a[i]>b[i]: return False
    return False

def is_equal(a, b):
    for i in range(len(a)):
        if a[i] != b[i]: return False
    return True

# lat1 = np.array(([48.78,47.45],[38.56,39.53]))
# lat2 = np.array(([7.78,5.45],[48.78,47.45],[7.56,5.53]))
lat1 = np.load('arr.npy')
lat2 = np.load('refarr.npy')

idx1 = np.lexsort( lat1.transpose() )
idx2 = np.lexsort( lat2.transpose() )
ii = 0
jj = 0
while ii < len(idx1) and jj < len(idx2):
    a = lat1[ idx1[ii] , : ]
    b = lat2[ idx2[jj] , : ]
    if is_equal( a, b ):
        # do stuff with match
        print "match found: lat1=%s lat2=%s %d and %d" % ( repr(a), repr(b), idx1[ii], idx2[jj] )
        ii += 1
        jj += 1
    elif is_less( a, b ):
        ii += 1
    else:
        jj += 1

Это может быть не совсем питоническим (возможно, кто-то может подумать о более приятной реализации с использованием генераторов или itertools?) но трудно представить себе какой-либо метод, который опирается на поиск одной точки за раз, обгоняя это по скорости.


вот решение. Я не уверен, что его можно векторизовать. Если вы хотите сделать его устойчивым к "float comparing error", вы должны изменить is_less и is_greater. Весь алгоритм-это просто двоичный поиск.

import numpy as np

#lexicographicaly compare two points - a and b

def is_less(a, b):
    i = 0
    while i<len(a):
        if a[i]<b[i]:
            return True
        else:
            if a[i]>b[i]:
                return False
        i+=1
    return False

def is_greater(a, b):
    i = 0
    while i<len(a):
        if a[i]>b[i]:
            return True
        else:
            if a[i]<b[i]:
                return False
        i+=1
    return False


def binary_search(a, x, lo=0, hi=None):
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo+hi)//2
        midval = a[mid]
        if is_less(midval, x):
            lo = mid+1
        elif is_greater(midval, x):
            hi = mid
        else:
            return mid
    return -1

def lex_sort(v): #sort by 1 and 2 column respectively
    #return v[np.lexsort((v[:,2],v[:,1]))]
    order = range(1, v.shape[1])
    return v[np.lexsort(tuple(v[:,i] for i in order[::-1]))]

def sort_and_index(arr):
    ind = np.indices((len(arr),)).reshape((len(arr), 1))
    arr = np.hstack([ind, arr]) # add an index column as first column
    arr = lex_sort(arr)
    arr_cut = arr[:,1:] # an array to do binary search in
    arr_ind = arr[:,:1] # shuffled indices
    return arr_ind, arr_cut


#lat1 = np.array(([1,2,3], [3,4,5], [5,6,7], [7,8,9])) # ~ 200000 rows
lat1 = np.arange(1,800001,1).reshape((200000,4))
#lat2 = np.array(([3,4,5], [5,6,7], [7,8,9], [1,2,3])) # same number of rows as time
lat2 = np.arange(101,800101,1).reshape((200000,4))

lat1_ind, lat1_cut = sort_and_index(lat1)

time_arr = np.zeros(200000)
import time
start = time.time()

for ii, elem in enumerate(lat2):
    pos = binary_search(lat1_cut, elem)
    if pos == -1:
        #Not found
        continue
    pos = lat1_ind[pos][0]
    #print "element in lat2 with index",ii,"has position",pos,"in lat1"
print time.time()-start

комментируемой печати-это место, где у вас есть соответствующие индексы lat1 и lat2. Работает в течение 7 секунд на 200000 строк.