返回
PyTorch保存和加载模型以及查看模型结构的方法,通俗易懂的入门指南
人工智能
2023-09-25 04:44:29
一、保存和加载模型
保存模型有两种最基本的方式:
1、保存整个网络:
torch.save(net, path1)
加载网络:
model = torch.load(path1)
2、保存模型参数:
torch.save(net.state_dict(), path2)
加载网络参数:
model.load_state_dict(torch.load(path2))
二、查看模型结构
PyTorch提供了两种查看模型结构的方法:
1、使用print()
函数打印模型结构:
print(model)
2、使用summary()
函数打印模型结构:
from torchsummary import summary
summary(model, input_size=(3, 224, 224))
三、示例
以下是一个使用PyTorch保存和加载模型的示例:
import torch
# 定义网络
net = torch.nn.Sequential(
torch.nn.Linear(784, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10),
)
# 训练网络
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(10):
for x, y in data_loader:
optimizer.zero_grad()
y_pred = net(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
# 保存模型
torch.save(net.state_dict(), 'model.pth')
# 加载模型
model = torch.nn.Sequential(
torch.nn.Linear(784, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10),
)
model.load_state_dict(torch.load('model.pth'))
# 使用模型进行预测
x = torch.randn(1, 784)
y_pred = model(x)
四、总结
本教程介绍了如何使用PyTorch保存和加载模型,以及如何查看模型的结构。这些都是PyTorch中最基本的操作,对于深度学习实践非常重要。