Лучший способ сохранить обученную модель в PyTorch?
Я искал альтернативные способы сохранить обученную модель в PyTorch. До сих пор я нашел две альтернативы.
- факел.save () сохранить модель и факел.load () для загрузки модели.
- модель.state_dict() сохранить обученную модель и модель.load_state_dict() для загрузки сохраненной модели.
Я дошел до этого обсуждение где подход 2 рекомендуется использовать подход 1.
мой вопрос в том, почему второй подход является предпочтительным? Это только потому, что факел.nn модули имеют эти две функции, и нам рекомендуется их использовать?
2 ответов
Я нашел на этой странице на их репозитории github я просто вставлю содержимое здесь.
рекомендуемый подход для сохранения модели
есть два основных подхода к сериализации и восстановления модели.
первый (рекомендуется) сохраняет и загружает только параметры модели:
torch.save(the_model.state_dict(), PATH)
позже:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
второй сохраняет и загружает всю модель:
torch.save(the_model, PATH)
затем позже:
the_model = torch.load(PATH)
однако в этом случае сериализованные данные привязаны к определенным классам и точная используемая структура каталогов, поэтому она может сломать в различных путях когда использованы в других проектах, или после серьезных операций рефакторинга.
это зависит от того, что вы хотите сделать.
Case # 1: сохраните модель, чтобы использовать ее самостоятельно для вывода: вы сохраняете модель, восстанавливаете ее, а затем меняете модель на режим оценки. Это делается потому, что у вас обычно есть BatchNorm
и Dropout
слои, которые по умолчанию находятся в режиме поезда на конструкции:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
случай # 2: сохранить модель, чтобы возобновить обучение позже: Если вам нужно продолжать обучение модели, которую вы собираетесь сохранить, нужно сохранить больше, чем просто модель. Вы также должны сохранять состояние оптимизатор, эпох, результат и т. д. Вы бы сделали это так:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
чтобы возобновить обучение, вы бы сделали такие вещи, как:state = torch.load(filepath)
, а затем, чтобы восстановить состояние каждого отдельного объекта, что-то вроде этого:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
так как вы возобновляете обучение,НЕ вызов model.eval()
после восстановления состояний при загрузке.
Дело № 3: Модель для использования кем-то другим без доступа к вашему коду:
В Tensorflow вы можете создать .pb
файл, определяющий архитектуру и вес модели. Это очень удобно, особенно при использовании Tensorflow serve
. Эквивалентным способом сделать это в Pytorch было бы:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
этот способ все еще не пуленепробиваемый, и поскольку pytorch все еще претерпевает много изменений, я бы не рекомендовал его.