PyTorch 序列化与保存
在深度学习中,模型的训练通常需要大量的时间和计算资源。为了避免每次需要时都重新训练模型,我们可以将训练好的模型保存下来,以便在以后的任务中直接加载和使用。PyTorch提供了多种方法来序列化和保存模型,本文将详细介绍这些方法。
什么是序列化?
序列化是指将数据结构或对象状态转换为可以存储或传输的格式的过程。在PyTorch中,序列化通常指的是将模型的状态(包括权重和优化器状态)保存到文件中,以便在需要时可以重新加载。
保存和加载模型
保存模型
在PyTorch中,我们可以使用 torch.save()
函数来保存模型。通常,我们会保存模型的 state_dict
,这是一个包含模型所有参数(权重和偏置)的字典。
import torch
import torch.nn as nn
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 实例化模型
model = SimpleModel()
# 保存模型的 state_dict
torch.save(model.state_dict(), 'model.pth')
加载模型
要加载保存的模型,我们首先需要实例化一个与保存时相同的模型结构,然后使用 torch.load()
加载 state_dict
,最后使用 load_state_dict()
方法将参数加载到模型中。
# 实例化模型
model = SimpleModel()
# 加载模型的 state_dict
model.load_state_dict(torch.load('model.pth'))
# 将模型设置为评估模式
model.eval()
备注
在加载模型后,记得调用 model.eval()
将模型设置为评估模式。这是因为某些层(如 Dropout
和 BatchNorm
)在训练和评估时的行为是不同的。
保存和加载整个模型
除了保存 state_dict
,我们还可以保存整个模型。这种方法更加简单,但不够灵活,因为它依赖于保存时的模型类和代码。
# 保存整个模型
torch.save(model, 'model_entire.pth')
# 加载整个模型
model = torch.load('model_entire.pth')
model.eval()
警告
保存整个模型可能会导致代码的兼容性问题,尤其是在模型类定义发生变化时。因此,推荐使用 state_dict
的方法来保存和加载模型。