使用 wandb 可视化训练过程

Tensorboard 对数据的保存仅限于本地,也很难分析超参数不同对实验的影响。wandb 的出现很好的解决了这些问题。wandb 是 Weights & Biases 的缩写,它能够自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与其他人共享结果。目前它能够和 Jupyter、TensorFlow、Pytorch、Keras、Scikit、fast.ai、LightGBM、XGBoost 一起结合使用。

wandb 的安装

wandb 的安装非常简单,我们只需要使用 pip 安装即可。

1
pip install wandb

安装完成后,我们需要在官网注册一个账号并复制下自己的 API keys,然后在本地使用下面的命令登录。

1
wandb login

这时,我们会看到下面的界面,只需要粘贴你的 API keys 即可。

wandb 的使用

wandb 的使用也非常简单,只需要在代码中添加几行代码即可。

1
2
import wandb
wandb.init(project='my-project', entity='my-name')

这里的 project 和 entity 是你在 wandb 上创建的项目名称和用户名,如果你还没有创建项目,可以参考官方文档

demo 演示

下面我们使用一个 CIFAR10 的图像分类 demo 来演示 wandb 的使用。

1
2
3
4
5
6
7
8
9
10
11
12

import random # to set the python random seed
import numpy # to set the numpy random seed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import warnings
warnings.filterwarnings('ignore')

使用 wandb 的第一步是初始化 wandb,这里我们使用 wandb.init()函数来初始化 wandb,其中 project 是你在 wandb 上创建的项目名称,name 是你的实验名称。

1
2
3
4
# 初始化wandb
import wandb
wandb.init(project="thorough-pytorch",
name="wandb_demo",)

使用 wandb 的第二步是设置超参数,这里我们使用 wandb.config 来设置超参数,这样我们就可以在 wandb 的界面上看到超参数的变化。wandb.config 的使用方法和字典类似,我们可以使用 config.key 的方式来设置超参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 超参数设置
config = wandb.config # config的初始化
config.batch_size = 64
config.test_batch_size = 10
config.epochs = 5
config.lr = 0.01
config.momentum = 0.1
config.use_cuda = True
config.seed = 2043
config.log_interval = 10

# 设置随机数
def set_seed(seed):
random.seed(config.seed)
torch.manual_seed(config.seed)
numpy.random.seed(config.seed)

第三步是构建训练和测试的 pipeline,这里我们使用 pytorch 的 CIFAR10 数据集和 resnet18 来构建训练和测试的 pipeline。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def train(model, device, train_loader, optimizer):
model.train()

for batch_id, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
loss.backward()
optimizer.step()

# wandb.log用来记录一些日志(accuracy,loss and epoch), 便于随时查看网路的性能
def test(model, device, test_loader, classes):
model.eval()
test_loss = 0
correct = 0
example_images = []

with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
criterion = nn.CrossEntropyLoss()
test_loss += criterion(output, target).item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
example_images.append(wandb.Image(
data[0], caption="Pred:{} Truth:{}".format(classes[pred[0].item()], classes[target[0]])))

# 使用wandb.log 记录你想记录的指标
wandb.log({
"Examples": example_images,
"Test Accuracy": 100. * correct / len(test_loader.dataset),
"Test Loss": test_loss
})

wandb.watch_called = False


def main():
use_cuda = config.use_cuda and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# 设置随机数
set_seed(config.seed)
torch.backends.cudnn.deterministic = True

# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载数据
train_loader = DataLoader(datasets.CIFAR10(
root='dataset',
train=True,
download=True,
transform=transform
), batch_size=config.batch_size, shuffle=True, **kwargs)

test_loader = DataLoader(datasets.CIFAR10(
root='dataset',
train=False,
download=True,
transform=transform
), batch_size=config.batch_size, shuffle=False, **kwargs)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

model = resnet18(pretrained=True).to(device)
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)

wandb.watch(model, log="all")
for epoch in range(1, config.epochs + 1):
train(model, device, train_loader, optimizer)
test(model, device, test_loader, classes)

# 本地和云端模型保存
torch.save(model.state_dict(), 'model.pth')
wandb.save('model.pth')


if __name__ == '__main__':
main()

当我们运行完上面的代码后,我们就可以在 wandb 的界面上看到我们的训练结果了和系统的性能指标。同时,我们还可以在 setting 里面设置训练完给我们发送邮件,这样我们就可以在训练完之后及时的查看训练结果了。


我们可以发现,使用 wandb 可以很方便的记录我们的训练结果,除此之外,wandb 还为我们提供了很多的功能,比如:模型的超参数搜索,模型的版本控制,模型的部署等等。这些功能都可以帮助我们更好的管理我们的模型,更好的进行模型的迭代和优化。