Октава / Matlab: эффективный calc внутреннего продукта Фробениуса?

У меня есть две матрицы A и B, и я хочу получить:

trace(A*B)

Если я не ошибаюсь это называется внутренний продукт Фробениуса.

моя забота здесь об эффективности. Я просто боюсь, что этот прямой подход сначала сделает все умножение (мои матрицы-тысячи строк/cols), и только затем возьмет след продукта, в то время как операция, в которой я действительно нуждаюсь, намного проще. Есть ли функция или синтаксис для этого эффективно?

3 ответов


правильно...суммирование элементно-мудрых продуктов будет быстрее:

n = 1000

A = randn(n);
B = randn(n);

tic
sum(sum(A .* B));
toc

tic
sum(diag(A * B'));
toc
Elapsed time is 0.010015 seconds.
Elapsed time is 0.130514 seconds.

sum(sum(A.*B)) избегает делать полное умножение матрицы


как насчет использования векторного умножения?

(A(:)')*B(:)

проверка времени выполнения

сравнение четырех вариантов с A и B размера 1000 на 1000:
1. вектор внутреннего продукта:A(:)'*B(:) (этот ответ) взял только 0.0011 sec.
2. Использование элемента мудрое умножение sum(sum(A.*B)) (Джонответ) взял 0.0035 sec.
3. Трейс!--8--> (предложено ОП) взял 0.054 sec.
4. Сумма диагонали sum(diag(A*B')) (опция отклонено Джон) взял 0.055 sec.

Take home message: Matlab чрезвычайно эффективен, когда дело доходит до матричного/векторного продукта. Использование векторного внутреннего продукта х3 раза быстрее даже эффективный элемент-мудрое решение умножения.


эталонный код Код, используемый для обеспечения проверки времени выполнения

t=zeros(1,4);
n=1000; % size of matrices
it=100; % average results over XX trails
for ii=1:it, 
    % random inputs
    A=rand(n);
    B=rand(n); 
    % John's rejected solution
    tic; 
    n1=sum(diag(A*B'));
    t(1)=t(1)+toc;
    % element-wise solution
    tic;
    n2=sum(sum(A.*B));
    t(2)=t(2)+toc;
    % MOST efficient solution - using vector product
    tic;
    n3=A(:)'*B(:);
    t(3)=t(3)+toc;
    % using trace
    tic;
    n4=trace(A*B');
    t(4)=t(4)+toc;
    % make sure everything is correct
    assert(abs(n1-n2)<1e-8 && abs(n3-n4)<1e-8 && abs(n1-n4)<1e-8);
end;
t./it

теперь вы можете запустить этот тест в клик.