Эффективное вычисление взвешенного расстояния в MATLAB

несколько посты exist об эффективном вычислении попарных расстояний в MATLAB. Эти сообщения, как правило, касаются быстрого вычисления евклидова расстояния между большим количеством точек.

мне нужно создать функцию, которая быстро вычисляет попарные различия между меньшим количеством точек (обычно менее 1000 пар). В рамках более широкой схемы программы, которую я пишу, эта функция будет выполняться многими тысячи раз, поэтому даже небольшие увеличения в эффективности важны. Функция должна быть гибкой двумя способами:

  1. при любом заданном вызове метрика расстояния может быть евклидовой или городской.
  2. взвешиваются размеры данных.

насколько я могу судить, решение этой конкретной проблемы не было опубликовано. В statstics Toolbox предлагает pdist и pdist2, которые принимают много различных расстояний функции, но не вес. Я видел расширения этих функций, которые позволяют взвешивать, но эти расширения не позволяют пользователям выбирать различные функции расстояния.

В идеале я хотел бы избежать использования функций из набора инструментов статистики (я не уверен, что пользователь функции будет иметь доступ к этим наборам инструментов).

Я написал две функции для выполнения этой задачи. Первый использует сложные вызовы repmat и permute, а второй просто использует for-loops.

function [D] = pairdist1(A, B, wts, distancemetric)

% get some information about the data
    numA = size(A,1);
    numB = size(B,1);

    if strcmp(distancemetric,'cityblock')
        r=1;
    elseif strcmp(distancemetric,'euclidean')
        r=2;
    else error('Function only accepts "cityblock" and "euclidean" distance')
    end

%   format weights for multiplication
    wts = repmat(wts,[numA,1,numB]);

%   get featural differences between A and B pairs
    A = repmat(A,[1 1 numB]);
    B = repmat(permute(B,[3,2,1]),[numA,1,1]);
    differences = abs(A-B).^r;

%   weigh difference values before combining them
    differences = differences.*wts;
    differences = differences.^(1/r);

%   combine features to get distance
    D = permute(sum(differences,2),[1,3,2]);
end

и:

function [D] = pairdist2(A, B, wts, distancemetric)

% get some information about the data
    numA = size(A,1);
    numB = size(B,1);

    if strcmp(distancemetric,'cityblock')
        r=1;
    elseif strcmp(distancemetric,'euclidean')
        r=2;
    else error('Function only accepts "cityblock" and "euclidean" distance')
    end

%   use for-loops to generate differences
    D = zeros(numA,numB);
    for i=1:numA
        for j=1:numB
            differences = abs(A(i,:) - B(j,:)).^(1/r);
            differences = differences.*wts;
            differences = differences.^(1/r);    
            D(i,j) = sum(differences,2);
        end
    end
end

вот тесты производительности:

A = rand(10,3);
B = rand(80,3);
wts = [0.1 0.5 0.4];
distancemetric = 'cityblock';


tic
D1 = pairdist1(A,B,wts,distancemetric);
toc

tic
D2 = pairdist2(A,B,wts,distancemetric);
toc

Elapsed time is 0.000238 seconds.
Elapsed time is 0.005350 seconds.

ясно, что версия repmat-and-permute работает намного быстрее, чем версия double-for-loop, по крайней мере, для небольших наборов данных. Но я также знаю,что вызовы repmat часто замедляют работу. Поэтому мне интересно, есть ли у кого-нибудь в сообществе SO какие-либо советы, чтобы предложить повысить эффективность функция!

редактировать

@Luis Mendo предложил хорошую очистку функции repmat-and-permute с помощью bsxfun. Я сравнил его функцию с моим оригиналом на наборах данных разного размера:

comparison

по мере увеличения данных версия bsxfun становится явным победителем!

правка #2

Я закончил писать функцию, и она доступна на github [ссылке]. Я в конечном итоге нашел довольно хороший векторизованный метод вычисления евклидова расстояния [ссылке], поэтому я использую этот метод в евклидовом случае, и я принял @Divakar по советы для города-блока. Он все еще не так быстр, как pdist2, но его нужно быстрее, чем любой из подходов, которые я изложил ранее в этом посте, и легко принимает взвешивания.

2 ответов


вы можете заменить repmat by bsxfun. Это позволяет избежать явного повторения, поэтому он более эффективен для памяти и, вероятно, быстрее:

function D = pairdist1(A, B, wts, distancemetric)

    if strcmp(distancemetric,'cityblock')
        r=1;
    elseif strcmp(distancemetric,'euclidean')
        r=2;
    else
        error('Function only accepts "cityblock" and "euclidean" distance')
    end

    differences  = abs(bsxfun(@minus, A, permute(B, [3 2 1]))).^r;
    differences = bsxfun(@times, differences, wts).^(1/r);
    D = permute(sum(differences,2),[1,3,2]);

end

на r = 1 ("cityblock" case), вы можете использовать bsxfun чтобы получить элементарные вычитания, а затем использовать matrix-multiplication, что должно ускорить процесс. Реализация будет выглядеть примерно так -

%// Calculate absolute elementiwse subtractions
absm = abs(bsxfun(@minus,permute(A,[1 3 2]),permute(B,[3 1 2])));

%// Perform matrix multiplications with the given weights and reshape
D = reshape(reshape(absm,[],size(A,2))*wts(:),size(A,1),[]);