返回

PyTorch保存和加载模型以及查看模型结构的方法,通俗易懂的入门指南

人工智能

一、保存和加载模型

保存模型有两种最基本的方式:

1、保存整个网络:

torch.save(net, path1)

加载网络:

model = torch.load(path1)

2、保存模型参数:

torch.save(net.state_dict(), path2)

加载网络参数:

model.load_state_dict(torch.load(path2))

二、查看模型结构

PyTorch提供了两种查看模型结构的方法:

1、使用print()函数打印模型结构:

print(model)

2、使用summary()函数打印模型结构:

from torchsummary import summary
summary(model, input_size=(3, 224, 224))

三、示例

以下是一个使用PyTorch保存和加载模型的示例:

import torch

# 定义网络
net = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# 训练网络
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(10):
    for x, y in data_loader:
        optimizer.zero_grad()
        y_pred = net(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

# 保存模型
torch.save(net.state_dict(), 'model.pth')

# 加载模型
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)
model.load_state_dict(torch.load('model.pth'))

# 使用模型进行预测
x = torch.randn(1, 784)
y_pred = model(x)

四、总结

本教程介绍了如何使用PyTorch保存和加载模型,以及如何查看模型的结构。这些都是PyTorch中最基本的操作,对于深度学习实践非常重要。