Что означает параметр retain graph в методе backward() переменной?

Я прохожу через нейронная передача pytorch учебник и я смущен использованием retain_variable(устарел, теперь называется retain_graph). Пример кода:

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_variables=True):
        #Why is retain_variables True??
        self.loss.backward(retain_variables=retain_variables)
        return self.loss

С документация

retain_graph (bool, необязательно) - если False, граф, используемый для вычисления град будет освобожден. Обратите внимание, что почти во всех случаях установка этого опция True не нужна и часто может быть обработана во многом более эффективный способ. По умолчанию используется значение create_graph.

установите retain_graph= True, мы не освобождаем память, выделенную для графика на обратном проходе. В чем преимущество сохранения этой памяти, зачем она нам нужна?

2 ответов


@cleros довольно по поводу использования retain_graph=True. По сути, он сохранит любую необходимую информацию для вычисления определенной переменной, чтобы мы могли сделать обратную передачу по ней.

наглядный пример

enter image description here

предположим, что у нас есть график расчета показано выше. Переменная d и e является выходом, и a вход. Например,

import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()

когда мы делаем d.backward(), это нормально. После этого вычисления часть графика, которая вычисляет d будет освобожден по умолчанию для экономии памяти. Так что если мы сделаем e.backward(), появится сообщение об ошибке. Для того, чтобы сделать e.backward(), мы должны установить параметр retain_graph to True на d.backward(), то есть,

d.backward(retain_graph=True)

если вы используете retain_graph=True в вашем обратном методе вы можете сделать Назад в любое время:

d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!

более полезное обсуждение можно найти здесь.

реальный случай использования

прямо сейчас, реальный случай использования многозадачного обучения, где у вас есть несколько потерь, которые могут быть на разных уровнях. Предположим, что у вас есть 2 потери: loss1 и loss2 и они находятся в разных слоях. Для того, чтобы backprop градиент loss1 и loss2 w.r.t к learnable весу вашей сети независимо. Вы должны использовать retain_graph=True на backward() метод в первом обратном распространении потеря.

# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters

Это очень полезная функция, когда у вас есть несколько выходных данных сети. Вот полностью составленный пример: представьте, что вы хотите построить некоторую случайную сверточную сеть, которую вы можете задать два вопроса: содержит ли входное изображение кошку и содержит ли изображение автомобиль?

один из способов сделать это - иметь сеть, которая разделяет сверточные слои, но имеет два параллельных слоя классификации (простите мой ужасный график ASCII, но это предполагается, что три convlayers, а затем три полностью Соединенных слоя, один для кошек и один для автомобилей):

                    -- FC - FC - FC - cat?
Conv - Conv - Conv -|
                    -- FC - FC - FC - car?

учитывая картину, что мы хотим запустить обе ветви, при обучении сети мы можем сделать это несколькими способами. Во-первых (что, вероятно, было бы лучше всего здесь, иллюстрируя, насколько плох пример), мы просто вычисляем потерю по обеим оценкам и суммируем потерю, а затем возвращаемся.

однако есть еще один сценарий, в котором мы хотим сделать это последовательно. Сначала мы хотим backprop через одну ветку, а затем через другую (у меня был этот прецедент раньше, поэтому он не полностью составлен). В этом случае, работает .backward() на одном графике также уничтожит любую информацию градиента в сверточных слоях, а сверточные вычисления второй ветви (поскольку они являются единственными, совместно используемыми с другой ветвью) больше не будут содержать график! Это означает, что когда мы пытаемся backprop через вторую ветку, Pytorch выдаст ошибку, так как не может найти график, соединяющий вход с выходом! В этих случаях мы можем решить проблему простым удержанием графика на первом обратном проходе. Затем график не будет потребляться, а будет потребляться только первым обратным проходом, который не требует его сохранения.

EDIT: если вы сохраняете график на всех обратных проходах, неявные определения графов, прикрепленные к выходным переменным, никогда не будут освобождены. Здесь может быть usecase как но я ничего не могу придумать. Поэтому в целом вы должны убедиться, что последний обратный проход освобождает память, не сохраняя информацию о графике.

Что касается того, что происходит для нескольких обратных проходов: как вы догадались, pytorch накапливает градиенты, добавляя их на место (к переменной/parameters .grad свойства). Это может быть очень полезно, Так как это означает, что цикл над партией и обработка ее один раз за раз, накапливая градиенты в конце, будет делать то же самое шаг оптимизации как выполнение полного пакетного обновления (который также суммирует все градиенты). Хотя полностью пакетированное обновление может быть более распараллелено и поэтому в целом предпочтительно, бывают случаи, когда пакетированное вычисление либо очень, очень трудно реализовать, либо просто невозможно. Однако, используя это накопление, мы все еще можем полагаться на некоторые из приятных стабилизирующих свойств, которые приносит дозирование. (Если не на увеличение производительности)