视频1 视频21 视频41 视频61 视频文章1 视频文章21 视频文章41 视频文章61 推荐1 推荐3 推荐5 推荐7 推荐9 推荐11 推荐13 推荐15 推荐17 推荐19 推荐21 推荐23 推荐25 推荐27 推荐29 推荐31 推荐33 推荐35 推荐37 推荐39 推荐41 推荐43 推荐45 推荐47 推荐49 关键词1 关键词101 关键词201 关键词301 关键词401 关键词501 关键词601 关键词701 关键词801 关键词901 关键词1001 关键词1101 关键词1201 关键词1301 关键词1401 关键词1501 关键词1601 关键词1701 关键词1801 关键词1901 视频扩展1 视频扩展6 视频扩展11 视频扩展16 文章1 文章201 文章401 文章601 文章801 文章1001 资讯1 资讯501 资讯1001 资讯1501 标签1 标签501 标签1001 关键词1 关键词501 关键词1001 关键词1501 专题2001
PyTorch快速搭建神经网络及其保存提取方法详解
2020-11-27 14:12:26 责编:小采
文档
 本篇文章主要介绍了PyTorch快速搭建神经网络及其保存提取方法详解,现在分享给大家,也给大家做个参考。一起过来看看吧

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
 def __init__(self, n_feature, n_hidden, n_output): 
 super(Net, self).__init__() 
 self.hidden = torch.nn.Linear(n_feature, n_hidden) 
 self.predict = torch.nn.Linear(n_hidden, n_output) 
 
 def forward(self, x): 
 x = F.relu(self.hidden(x)) 
 x = self.predict(x) 
 return x 
 
net1 = Net(2, 10, 2) 
print('方法1:
', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
 torch.nn.Linear(2, 10), 
 torch.nn.ReLU(), 
 torch.nn.Linear(10, 2), 
 ) 
print('方法2:
', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
 # 神经网络结构 
 net1 = torch.nn.Sequential( 
 torch.nn.Linear(1, 10), 
 torch.nn.ReLU(), 
 torch.nn.Linear(10, 1), 
 ) 
 optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
 loss_function = torch.nn.MSELoss() 
 
 # 训练部分 
 for i in range(300): 
 prediction = net1(x) 
 loss = loss_function(prediction, y) 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 
 # 绘图部分 
 plt.figure(1, figsize=(10, 3)) 
 plt.subplot(131) 
 plt.title('net1') 
 plt.scatter(x.data.numpy(), y.data.numpy()) 
 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
 # 保存神经网络 
 torch.save(net1, '7-net.pkl') # 保存整个神经网络的结构和模型参数 
 torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
 net2 = torch.load('7-net.pkl') 
 prediction = net2(x) 
 
 plt.subplot(132) 
 plt.title('net2') 
 plt.scatter(x.data.numpy(), y.data.numpy()) 
 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
 # 首先搭建相同的神经网络结构 
 net3 = torch.nn.Sequential( 
 torch.nn.Linear(1, 10), 
 torch.nn.ReLU(), 
 torch.nn.Linear(10, 1), 
 ) 
 
 # 载入神经网络的模型参数 
 net3.load_state_dict(torch.load('7-net_params.pkl')) 
 prediction = net3(x) 
 
 plt.subplot(133) 
 plt.title('net3') 
 plt.scatter(x.data.numpy(), y.data.numpy()) 
 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()

实验结果:

下载本文
显示全文
专题