概述
使用 PyTorch Lightning 进行模型训练可以简化深度学习项目的开发流程,提高代码的可读性和可维护性。以下是使用 PyTorch Lightning 完成模型训练的主要步骤:
安装 PyTorch Lightning
首先,确保已安装 PyTorch Lightning。可以使用以下命令通过 pip 安装:
1
pip install pytorch-lightning定义 LightningModule
创建一个继承自
pl.LightningModule的类,用于定义模型结构、前向传播、损失计算和优化器配置等。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
class LitModel(pl.LightningModule):
def __init__(self):
super(LitModel, self).__init__()
self.layer = nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.layer(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)准备数据
使用
LightningDataModule或自定义数据加载器来处理数据的加载和预处理。1
2
3
4
5
6
7
8
9
10from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = MNIST(root='data', train=True, download=True, transform=transform)
train_set, val_set = random_split(dataset, [55000, 5000])
train_loader = DataLoader(train_set, batch_size=64, num_workers=4)
val_loader = DataLoader(val_set, batch_size=64, num_workers=4)初始化 Trainer 并开始训练
使用
pl.Trainer初始化训练器,设置训练参数,并调用fit方法开始训练。1
2
3
4
5from pytorch_lightning import Trainer
model = LitModel()
trainer = Trainer(max_epochs=10, gpus=1)
trainer.fit(model, train_loader, val_loader)模型验证和测试
在训练完成后,可以使用验证集或测试集评估模型性能。
1
2
3test_set = MNIST(root='data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=64, num_workers=4)
trainer.test(model, test_loader)
通过以上步骤,您可以使用 PyTorch Lightning 高效地完成模型的训练和评估过程。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 Hongwei Zhao's Blog!



