计算数据集的均值和标准差
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| import os import cv2 import numpy as np from torch.utils.data import Dataset from PIL import Image
def compute_mean_and_std(dataset):
    """Compute the per-channel mean and standard deviation of an image dataset.

    Both statistics are taken over every pixel of every image, so images of
    different sizes contribute in proportion to their pixel count (for
    equally sized images this equals the average of the per-image means).
    The results are divided by 255 before being returned.

    Args:
        dataset: Re-iterable collection yielding ``(image, label)`` pairs.
            Each image must convert through ``np.asarray`` to an
            ``H x W x C`` array with at least three channels; only the first
            three channels are used and the label is ignored.

    Returns:
        A ``(mean, std)`` pair of 3-tuples of floats, ordered by channel
        index (channels 0, 1, 2 of the input arrays).
        NOTE(review): the original variable names treated channel 0 as blue
        (OpenCV BGR order); for PIL-loaded RGB images channel 0 is red --
        confirm which order the dataset produces.

    Raises:
        ValueError: If the dataset contains no pixels.
    """
    # Pass 1: accumulate per-channel sums and the total pixel count so the
    # mean is weighted by pixels, consistent with the variance below.
    channel_sum = np.zeros(3, dtype=np.float64)
    num_pixels = 0
    for img, _ in dataset:
        pixels = np.asarray(img, dtype=np.float64)[:, :, :3]
        channel_sum += pixels.sum(axis=(0, 1))
        num_pixels += pixels.shape[0] * pixels.shape[1]

    if num_pixels == 0:
        raise ValueError('cannot compute mean/std of an empty dataset')
    channel_mean = channel_sum / num_pixels

    # Pass 2: accumulate squared deviations from the global channel means.
    squared_diff = np.zeros(3, dtype=np.float64)
    for img, _ in dataset:
        pixels = np.asarray(img, dtype=np.float64)[:, :, :3]
        squared_diff += np.square(pixels - channel_mean).sum(axis=(0, 1))
    channel_std = np.sqrt(squared_diff / num_pixels)

    mean = tuple((channel_mean / 255.0).tolist())
    std = tuple((channel_std / 255.0).tolist())
    return mean, std
|
得到视频数据基本信息
1 2 3 4 5 6 7
import cv2

# Read basic properties (frame size, frame count, frame rate) of the video
# at `mp4_path`, which must be defined by the surrounding code.
video = cv2.VideoCapture(mp4_path)
try:
    # VideoCapture does not raise on a bad path; without this check every
    # property below would silently read as 0.
    if not video.isOpened():
        raise IOError('Cannot open video: {}'.format(mp4_path))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # NOTE: int() truncates fractional frame rates (e.g. 29.97 -> 29); kept
    # as an int so existing users of `fps` see the same type as before.
    fps = int(video.get(cv2.CAP_PROP_FPS))
finally:
    # Always free the underlying capture handle, even if a read fails.
    video.release()
|
TSN 每段(segment)采样一帧视频
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| K = self._num_segments if is_train: if num_frames > K: frame_indices = torch.randint( high=num_frames // K, size=(K,), dtype=torch.long) frame_indices += num_frames // K * torch.arange(K) else: frame_indices = torch.randint( high=num_frames, size=(K - num_frames,), dtype=torch.long) frame_indices = torch.sort(torch.cat(( torch.arange(num_frames), frame_indices)))[0] else: if num_frames > K: frame_indices = num_frames / K // 2 frame_indices += num_frames // K * torch.arange(K) else: frame_indices = torch.sort(torch.cat(( torch.arange(num_frames), torch.arange(K - num_frames))))[0] assert frame_indices.size() == (K,) return [frame_indices[i] for i in range(K)]
|
常用训练和验证数据预处理
其中 ToTensor 操作会将 PIL.Image 或形状为 $H\times W\times D$、数值范围为 [0, 255](dtype 为 np.uint8)的 np.ndarray 转换为形状为 $D\times H\times W$、数值范围为 [0.0, 1.0] 的 torch.FloatTensor。注意:对于其他 dtype 的 np.ndarray,ToTensor 只会调整维度顺序,不会对数值进行缩放。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Per-channel normalization statistics shared by both pipelines.
_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
_NORMALIZE_STD = (0.229, 0.224, 0.225)

# Training: random scale/aspect crop to 224, random horizontal flip, then
# conversion to a tensor and per-channel normalization.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORMALIZE_MEAN,
                                     std=_NORMALIZE_STD),
])

# Validation: deterministic resize to 256 followed by a 224 center crop,
# then the same tensor conversion and normalization as training.
val_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORMALIZE_MEAN,
                                     std=_NORMALIZE_STD),
])
|
BN操作
多卡同步 BN(Batch normalization)
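当使用 torch.nn.DataParallel 将代码运行在多张 GPU 卡上时,PyTorch 的 BN 层默认操作是各卡上数据独立地计算均值和标准差,同步 BN 使用所有卡上的数据一起计算 BN 层的均值和标准差,缓解了当批量大小(batch size)比较小时对均值和标准差估计不准的情况,是在目标检测等任务中一个有效的提升性能的技巧。需要注意的是,torch.nn.SyncBatchNorm 目前仅支持 torch.nn.parallel.DistributedDataParallel(每个进程使用一张 GPU),在 torch.nn.DataParallel 下不会进行跨卡同步。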
1 2 3 4 5 6
# Synchronized batch-norm layer; `num_features` must be defined by the
# surrounding code. The keyword arguments are spelled out explicitly.
sync_bn = torch.nn.SyncBatchNorm(
    num_features,
    eps=1e-05,
    momentum=0.1,
    affine=True,
    track_running_stats=True,
)
|
将已有网络的所有 BN 层改为同步 BN 层
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
def convertBNtoSyncBN(module, process_group=None):
    """Recursively replace every BN layer in ``module`` with a SyncBN layer.

    Args:
        module (torch.nn.Module): Network (or single layer) to convert.
        process_group: Passed through to ``torch.nn.SyncBatchNorm``.

    Returns:
        torch.nn.Module: A new ``SyncBatchNorm`` when ``module`` itself is a
        batch-norm layer; otherwise ``module`` itself, with its batch-norm
        descendants replaced in place.
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        sync_bn = torch.nn.SyncBatchNorm(
            module.num_features, module.eps, module.momentum,
            module.affine, module.track_running_stats, process_group)
        # Carry over the running statistics and the train/eval state so the
        # converted layer behaves like the one it replaces.
        sync_bn.running_mean = module.running_mean
        sync_bn.running_var = module.running_var
        sync_bn.num_batches_tracked = module.num_batches_tracked
        sync_bn.training = module.training
        if module.affine:
            # `weight`/`bias` are registered parameters, so a plain tensor
            # cannot be assigned to them; wrap the detached copies.
            sync_bn.weight = torch.nn.Parameter(
                module.weight.detach().clone(),
                requires_grad=module.weight.requires_grad)
            sync_bn.bias = torch.nn.Parameter(
                module.bias.detach().clone(),
                requires_grad=module.bias.requires_grad)
        return sync_bn

    # Not a BN layer: convert each child and re-attach it under its name.
    for name, child_module in module.named_children():
        setattr(module, name,
                convertBNtoSyncBN(child_module, process_group=process_group))
    return module
|
类似 BN 滑动平均
如果要实现类似 BN 滑动平均的操作,在 forward 函数中要使用原地(inplace)操作给滑动平均赋值。
1 2 3 4 5 6 7 8
class BN(torch.nn.Module):
    """Sketch of a module that keeps a BN-style running average in a buffer.

    The ``...`` placeholders stand for code elided in the original snippet.
    """

    def __init__(self):
        # Elided setup; presumably calls super().__init__() (required before
        # register_buffer) and defines `num_features` -- TODO confirm.
        ...
        # A buffer is saved with the module state but is not a trainable
        # parameter.
        self.register_buffer('running_mean', torch.zeros(num_features))

    def forward(self, X):
        # Elided computation; presumably produces `current` (the statistic
        # for this batch) and `momentum` -- TODO confirm.
        ...
        # Update the buffer in place so the registered tensor is kept, and
        # outside autograd so the running average does not join the graph.
        with torch.no_grad():
            self.running_mean += momentum * (current - self.running_mean)
|