Pytorch 保存、读取和修改网络模型
1.模型的保存与读取
模型的保存与读取
- 模型的保存
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构加模型参数
torch.save(vgg16, 'vgg16_method1.pth')
# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')
- 模型的读取
import torch
import torchvision
# 保存方式1 加载模型
model1 = torch.load('vgg16_method1.pth')
print(model1)
# 保存方式2 加载模型 -- 需要创建模型,然后再加载参数
model2 = torchvision.models.vgg16(pretrained=False)
model2.load_state_dict(torch.load('vgg16_method2.pth'))
print(model2)
注意方法一可能会报错,只需要将定义模型的类引入即可,例如: from model_save import
2.网络模型的修改
2.1 冻结模型的梯度
for para in origin_model.parameters(): # 冻结整个模型的参数
para.requires_grad = False
# 冻结指定层的预训练参数:
net.feature[26].weight.requires_grad = False
2.2 在Sequential添加一层
# features 为 nn.Sequential
net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))
2.3 修改Sequential中的某一层
net.classifier[6] = nn.Linear(1000, 5) # 前面的 6 根据实际修改
2.4 删除Sequential中的某一层
直接使用nn.Sequential()
对改层设置为空即可
net.features[13] = nn.Sequential()
2.5 删除Sequential的某些层
net.features = nn.Sequential(*list(net.features.children())[:-4]) # 删除网络的后 4 层