使用 Apex 再加速
Apex 是 NVIDIA 开源的用于混合精度训练和分布式训练库。Apex 对混合精度训练的过程进行了封装,改两三行配置就可以进行混合精度的训练,从而大幅度降低显存占用,节约运算时间。此外,Apex 也提供了对分布式训练的封装,针对 NVIDIA 的 NCCL 通信库进行了优化。
在混合精度训练上,Apex 的封装十分优雅。直接使用 amp.initialize 包装模型和优化器,apex 就会自动帮助我们管理模型参数和优化器的精度了,根据精度需求不同可以传入其他配置参数。
1 2 3
| from apex import amp
model, optimizer = amp.initialize(model, optimizer)
|
在分布式训练的封装上,Apex 在胶水层的改动并不大,主要是优化了 NCCL 的通信。因此,大部分代码仍与 torch.distributed 保持一致。使用的时候只需要将 torch.nn.parallel.DistributedDataParallel 替换为 apex.parallel.DistributedDataParallel 用于包装模型。在 API 层面,相对于 torch.distributed ,它可以自动管理一些参数(可以少传一点):
1 2 3 4 5 6
| from apex.parallel import DistributedDataParallel
model = DistributedDataParallel(model)
|
在正向传播计算 loss 时,Apex 需要使用 amp.scale_loss 包装,用于根据 loss 值自动对精度进行缩放:
1 2
| with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()
|
汇总一下,Apex 的并行训练部分主要与如下代码段有关:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| import torch import argparse import torch.distributed as dist
from apex.parallel import DistributedDataParallel
parser = argparse.ArgumentParser() parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') args = parser.parse_args()
dist.init_process_group(backend='nccl') torch.cuda.set_device(args.local_rank)
train_dataset = ... train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
model = ... model, optimizer = amp.initialize(model, optimizer) model = DistributedDataParallel(model, device_ids=[args.local_rank])
optimizer = optim.SGD(model.parameters())
for epoch in range(100): for batch_idx, (data, target) in enumerate(train_loader): images = images.cuda(non_blocking=True) target = target.cuda(non_blocking=True) ... output = model(images) loss = criterion(output, target) optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step()
|
在使用时,调用 torch.distributed.launch 启动器启动:
1
| UDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
|
在 ImageNet 上的完整训练代码,请点击Github。