Лучший способ сохранить обученную модель в PyTorch?

Я искал альтернативные способы сохранить обученную модель в PyTorch. До сих пор я нашел две альтернативы.

  1. факел.save () сохранить модель и факел.load () для загрузки модели.
  2. модель.state_dict() сохранить обученную модель и модель.load_state_dict() для загрузки сохраненной модели.

Я дошел до этого обсуждение где подход 2 рекомендуется использовать подход 1.

мой вопрос в том, почему второй подход является предпочтительным? Это только потому, что факел.nn модули имеют эти две функции, и нам рекомендуется их использовать?

2 ответов


Я нашел на этой странице на их репозитории github я просто вставлю содержимое здесь.


рекомендуемый подход для сохранения модели

есть два основных подхода к сериализации и восстановления модели.

первый (рекомендуется) сохраняет и загружает только параметры модели:

torch.save(the_model.state_dict(), PATH)

позже:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

второй сохраняет и загружает всю модель:

torch.save(the_model, PATH)

затем позже:

the_model = torch.load(PATH)

однако в этом случае сериализованные данные привязаны к определенным классам и точная используемая структура каталогов, поэтому она может сломать в различных путях когда использованы в других проектах, или после серьезных операций рефакторинга.


это зависит от того, что вы хотите сделать.

Case # 1: сохраните модель, чтобы использовать ее самостоятельно для вывода: вы сохраняете модель, восстанавливаете ее, а затем меняете модель на режим оценки. Это делается потому, что у вас обычно есть BatchNorm и Dropout слои, которые по умолчанию находятся в режиме поезда на конструкции:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

случай # 2: сохранить модель, чтобы возобновить обучение позже: Если вам нужно продолжать обучение модели, которую вы собираетесь сохранить, нужно сохранить больше, чем просто модель. Вы также должны сохранять состояние оптимизатор, эпох, результат и т. д. Вы бы сделали это так:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

чтобы возобновить обучение, вы бы сделали такие вещи, как:state = torch.load(filepath), а затем, чтобы восстановить состояние каждого отдельного объекта, что-то вроде этого:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

так как вы возобновляете обучение,НЕ вызов model.eval() после восстановления состояний при загрузке.

Дело № 3: Модель для использования кем-то другим без доступа к вашему коду: В Tensorflow вы можете создать .pb файл, определяющий архитектуру и вес модели. Это очень удобно, особенно при использовании Tensorflow serve. Эквивалентным способом сделать это в Pytorch было бы:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

этот способ все еще не пуленепробиваемый, и поскольку pytorch все еще претерпевает много изменений, я бы не рекомендовал его.