Pytorch 保存、读取和修改网络模型

1.模型的保存与读取

模型的保存与读取

  • 模型的保存
import torch
 import torchvision
 ​
 vgg16 = torchvision.models.vgg16(pretrained=False)
 # 保存方式1,模型结构加模型参数
 torch.save(vgg16, 'vgg16_method1.pth')
 ​
 # 保存方式2,模型参数(官方推荐)
 torch.save(vgg16.state_dict(), 'vgg16_method2.pth')
  • 模型的读取
import torch
 import torchvision
 ​
 # 保存方式1 加载模型
 model1 = torch.load('vgg16_method1.pth')
 print(model1)
 ​
 # 保存方式2 加载模型 -- 需要创建模型,然后再加载参数
 model2 = torchvision.models.vgg16(pretrained=False)
 model2.load_state_dict(torch.load('vgg16_method2.pth'))
 print(model2)
注意方法一可能会报错,只需要将定义模型的类引入即可,例如: from model_save import

2.网络模型的修改

2.1 冻结模型的梯度

for para in origin_model.parameters():  # 冻结整个模型的参数
 para.requires_grad = False
 ​
 # 冻结指定层的预训练参数:
 net.feature[26].weight.requires_grad = False

2.2 在Sequential添加一层

# features 为 nn.Sequential
 net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))

2.3 修改Sequential中的某一层

net.classifier[6] = nn.Linear(1000, 5)  # 前面的 6 根据实际修改

2.4 删除Sequential中的某一层

直接使用nn.Sequential()对改层设置为空即可

net.features[13] = nn.Sequential()

2.5 删除Sequential的某些层

net.features = nn.Sequential(*list(net.features.children())[:-4])  # 删除网络的后 4 层
最后修改:2022 年 06 月 25 日 09 : 24 PM
如果觉得我的文章对你有用,请随意赞赏