计算数据集的均值和标准差

import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image

def compute_mean_and_std(dataset):
    """Compute the per-channel mean and standard deviation of an image dataset.

    Statistics are pixel-weighted: every pixel of every image contributes
    equally, so the mean and the standard deviation stay consistent even
    when the images differ in size. The dataset is traversed only once.

    Args:
        dataset: Iterable yielding ``(img, label)`` pairs. ``img`` must be
            convertible by ``np.asarray`` into an array that can be indexed
            as ``img[:, :, c]`` for ``c`` in 0..2 (for example a PIL image).
            Pixel values are expected on a 0-255 scale because the results
            are divided by 255.

    Returns:
        A pair ``(mean, std)``. Each is a 3-tuple of Python floats given in
        channel-index order (channel 0, 1, 2 of the input arrays), scaled
        by 1/255.

    Raises:
        ValueError: If the dataset yields no pixels at all.
        IndexError: If an image has fewer than three channels.
    """
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    num_pixels = 0

    for img, _ in dataset:
        # Convert to float64 before squaring; uint8 arithmetic would wrap.
        pixels = np.asarray(img, dtype=np.float64)[:, :, [0, 1, 2]]
        channel_sum += pixels.sum(axis=(0, 1))
        channel_sq_sum += np.square(pixels).sum(axis=(0, 1))
        num_pixels += pixels.shape[0] * pixels.shape[1]

    if num_pixels == 0:
        raise ValueError('cannot compute statistics of an empty dataset')

    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    variance = np.maximum(channel_sq_sum / num_pixels - np.square(mean), 0.0)
    std = np.sqrt(variance)

    return tuple((mean / 255.0).tolist()), tuple((std / 255.0).tolist())

得到视频数据基本信息

import cv2
# Read basic metadata (frame size, frame count, frame rate) of the video
# file at `mp4_path`; `mp4_path` must be defined before this snippet runs.
video = cv2.VideoCapture(mp4_path)
try:
    # Fail loudly instead of silently reporting zeros for a bad path.
    if not video.isOpened():
        raise IOError('Cannot open video file: {}'.format(mp4_path))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # NOTE: int() truncates fractional frame rates (e.g. 29.97 becomes 29).
    fps = int(video.get(cv2.CAP_PROP_FPS))
finally:
    # Always free the capture handle, even if a query above raises.
    video.release()

TSN 每段(segment)采样一帧视频

K = self._num_segments
if is_train:
if num_frames > K:
# Random index for each segment.
frame_indices = torch.randint(
high=num_frames // K, size=(K,), dtype=torch.long)
frame_indices += num_frames // K * torch.arange(K)
else:
frame_indices = torch.randint(
high=num_frames, size=(K - num_frames,), dtype=torch.long)
frame_indices = torch.sort(torch.cat((
torch.arange(num_frames), frame_indices)))[0]
else:
if num_frames > K:
# Middle index for each segment.
frame_indices = num_frames / K // 2
frame_indices += num_frames // K * torch.arange(K)
else:
frame_indices = torch.sort(torch.cat((
torch.arange(num_frames), torch.arange(K - num_frames))))[0]
assert frame_indices.size() == (K,)
return [frame_indices[i] for i in range(K)]

常用训练和验证数据预处理

其中 ToTensor 操作会将 PIL.Image,或形状为 $H\times W\times D$、数值范围为 [0, 255](dtype 为 uint8)的 np.ndarray,转换为形状为 $D\times H\times W$、数值范围为 [0.0, 1.0] 的 torch.Tensor。

def _tensor_and_normalize():
    """Build the tail shared by both pipelines: ToTensor, then Normalize.

    Fresh transform instances are created on every call, so the training
    and validation pipelines do not share objects.
    """
    return [
        torchvision.transforms.ToTensor(),
        # Fixed per-channel statistics (these look like the values commonly
        # used with ImageNet-pretrained models; confirm for your data).
        torchvision.transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]


# Training: random-area crop resized to 224 plus a random horizontal flip.
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
        torchvision.transforms.RandomHorizontalFlip(),
    ]
    + _tensor_and_normalize()
)

# Validation: deterministic resize to 256 followed by a 224 center crop.
val_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
    ]
    + _tensor_and_normalize()
)

BN操作

多卡同步 BN(Batch normalization)

当在多张 GPU 卡上训练时,PyTorch 的 BN 层默认在各卡上用各自的数据独立地计算均值和标准差;同步 BN 则使用所有卡上的数据一起计算 BN 层的均值和标准差,缓解了批量大小(batch size)较小时均值和标准差估计不准的问题,是目标检测等任务中一个有效的提升性能的技巧。需要注意的是,torch.nn.SyncBatchNorm 只能配合 torch.nn.parallel.DistributedDataParallel(每个进程一张 GPU)使用,在 torch.nn.DataParallel 下不起作用。

# Build a synchronized batch-norm layer for `num_features` channels, with
# every keyword argument spelled out explicitly. `num_features` must be
# defined by the surrounding code.
# NOTE(review): per the PyTorch docs, SyncBatchNorm is only supported under
# DistributedDataParallel with one GPU per process -- confirm your setup.
sync_bn = torch.nn.SyncBatchNorm(num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)

将已有网络的所有 BN 层改为同步 BN 层

def convertBNtoSyncBN(module, process_group=None):
    """Recursively replace every BN layer in ``module`` with a SyncBN layer.

    Args:
        module (torch.nn.Module): Network, or a single layer, to convert.
        process_group: Forwarded unchanged to ``torch.nn.SyncBatchNorm``.

    Returns:
        torch.nn.Module: A new ``SyncBatchNorm`` when ``module`` itself is a
        batch-norm layer. Otherwise the same ``module`` object, with its
        batch-norm descendants replaced in place.
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        sync_bn = torch.nn.SyncBatchNorm(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
            process_group=process_group,
        )
        # Carry over the running statistics; these are None when the source
        # layer does not track them.
        sync_bn.running_mean = module.running_mean
        sync_bn.running_var = module.running_var
        sync_bn.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            # A registered parameter only accepts torch.nn.Parameter (or
            # None), so wrap the copied tensors instead of assigning them
            # directly.
            sync_bn.weight = torch.nn.Parameter(
                module.weight.detach().clone(),
                requires_grad=module.weight.requires_grad,
            )
            sync_bn.bias = torch.nn.Parameter(
                module.bias.detach().clone(),
                requires_grad=module.bias.requires_grad,
            )
        # Preserve the train/eval mode of the layer being replaced.
        sync_bn.train(module.training)
        return sync_bn

    for name, child_module in module.named_children():
        setattr(module, name,
                convertBNtoSyncBN(child_module, process_group=process_group))
    return module

类似 BN 滑动平均

如果要实现类似 BN 滑动平均的操作,需要在 forward 函数中使用原地(inplace)操作来更新滑动平均值,这样更新会直接写入已注册的 buffer,而不是把名字重新绑定到一个新的张量上。

class BN(torch.nn.Module):
    """Sketch of a layer that keeps a BN-style running mean.

    This is an illustrative skeleton: the ``...`` lines stand for code the
    snippet leaves out, and ``num_features``, ``momentum`` and ``current``
    are placeholders that the omitted code must define.
    """

    def __init__(self):
        ...
        # Register the running mean as a buffer rather than a parameter.
        self.register_buffer('running_mean', torch.zeros(num_features))

    def forward(self, X):
        ...
        # In-place update, so the registered buffer itself is modified
        # instead of the attribute being rebound to a new tensor.
        # NOTE(review): `current` should presumably be detached from the
        # autograd graph before this update -- confirm in the omitted code.
        self.running_mean += momentum * (current - self.running_mean)