1. 导入必要的模块¶
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
import torch.nn.functional as F
2. 初始化 WandbLogger¶
- 替代手动
wandb.init:通过WandbLogger配置所有参数。 - 关键参数:
project:WandB 项目名称。name:实验名称(支持动态生成,如f"{args.method}_{args.dataset}")。group:实验分组(可选)。save_code:是否保存代码快照(默认True)。config:记录超参数(如args)。
wandb_logger = WandbLogger(
project="ESN",
name=f"{args.method}_{args.dataset}", # 动态名称示例
group=f"{args.dataset}_group",
notes="实验备注",
save_code=True,
config=args # 记录所有超参数
)
3. 配置 Trainer¶
- 关键点:将
WandbLogger实例传递给logger参数。 - 其他配置:如训练设备、最大轮次、回调等。
trainer = Trainer(
logger=wandb_logger, # 关键行:启用 WandB 日志
default_root_dir="./checkpoints/",
accelerator="gpu",
devices=1,
max_epochs=100,
callbacks=[
ModelCheckpoint(...), # 模型保存回调
LearningRateMonitor(...), # 学习率监控
]
)
4. 在 LightningModule 中记录损失¶
- 使用
self.log:在training_step中记录损失值。 - 避免名称冲突:使用自定义名称(如
train_loss_epoch)而非默认的train_loss。 - 参数一致性:确保所有
self.log调用对同一名称的参数一致。
class MyModel(LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
# 记录训练损失(自定义名称)
self.log(
"train_loss_epoch", # 唯一名称,避免冲突
loss,
on_step=False, # 不记录每一步
on_epoch=True, # 记录每个 epoch 的平均值
prog_bar=True, # 在进度条显示
logger=True # 同步到 WandB
)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
5. 运行训练并验证¶
- 启动训练:
python
model = MyModel()
trainer.fit(model, train_dataloader)
-
检查 WandB:
-
登录 WandB 网页端,查看对应项目。
- 在 Metrics 面板查看
train_loss_epoch曲线。 - 在 Overview 面板查看超参数和代码快照。
常见问题排查¶
- 指标未显示:
- 确认logger=True已设置。
- 检查WandbLogger是否传递给Trainer。 - 名称冲突报错:
- 确保所有self.log调用中名称唯一且参数一致。
- 避免使用 Lightning 的默认名称(如train_loss)。 - WandB 无数据:
- 运行wandb login确保已登录。
- 检查网络连接或 API 密钥有效性。
最终效果¶
- 自动记录:每个 epoch 的
train_loss_epoch自动同步到 WandB。 - 交互式曲线:WandB 自动生成可缩放、可筛选的损失曲线。
- 实验管理:所有超参数和代码版本均被记录,确保实验可复现。
评论