1. 导入必要的模块

1
2
3
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import WandbLogger
import torch.nn.functional as F

2. 初始化 WandbLogger

  • 替代手动 wandb.init:通过 WandbLogger 配置所有参数。
  • 关键参数
    • project:WandB 项目名称。
    • name:实验名称(支持动态生成,如 f"{args.method}_{args.dataset}")。
    • group:实验分组(可选)。
    • save_code:是否保存代码快照(默认 True)。
    • config:记录超参数(如 args)。
1
2
3
4
5
6
7
8
wandb_logger = WandbLogger(
project="ESN",
name=f"{args.method}_{args.dataset}", # 动态名称示例
group=f"{args.dataset}_group",
notes="实验备注",
save_code=True,
config=args # 记录所有超参数
)

3. 配置 Trainer

  • 关键点:将 WandbLogger 实例传递给 logger 参数。
  • 其他配置:如训练设备、最大轮次、回调等。
1
2
3
4
5
6
7
8
9
10
11
trainer = Trainer(
logger=wandb_logger, # 关键行:启用 WandB 日志
default_root_dir="./checkpoints/",
accelerator="gpu",
devices=1,
max_epochs=100,
callbacks=[
ModelCheckpoint(...), # 模型保存回调
LearningRateMonitor(...), # 学习率监控
]
)

4. 在 LightningModule 中记录损失

  • 使用 self.log:在 training_step 中记录损失值。
  • 避免名称冲突:使用自定义名称(如 train_loss_epoch)而非默认的 train_loss
  • 参数一致性:确保所有 self.log 调用对同一名称的参数一致。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class MyModel(LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)

# 记录训练损失(自定义名称)
self.log(
"train_loss_epoch", # 唯一名称,避免冲突
loss,
on_step=False, # 不记录每一步
on_epoch=True, # 记录每个 epoch 的平均值
prog_bar=True, # 在进度条显示
logger=True # 同步到 WandB
)
return loss

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)

5. 运行训练并验证

  • 启动训练

    1
    2
    model = MyModel()
    trainer.fit(model, train_dataloader)
  • 检查 WandB

    • 登录 WandB 网页端,查看对应项目。
    • Metrics 面板查看 train_loss_epoch 曲线。
    • Overview 面板查看超参数和代码快照。

常见问题排查

  1. 指标未显示
    • 确认 logger=True 已设置。
    • 检查 WandbLogger 是否传递给 Trainer
  2. 名称冲突报错
    • 确保所有 self.log 调用中名称唯一且参数一致。
    • 避免使用 Lightning 的默认名称(如 train_loss)。
  3. WandB 无数据
    • 运行 wandb login 确保已登录。
    • 检查网络连接或 API 密钥有效性。

最终效果

  • 自动记录:每个 epoch 的 train_loss_epoch 自动同步到 WandB。
  • 交互式曲线:WandB 自动生成可缩放、可筛选的损失曲线。
  • 实验管理:所有超参数和代码版本均被记录,确保实验可复现。