timm-vit代码解读
视觉 Transformer 优秀开源工作:timm 库 vision transformer 代码解读
timm库 vision_transformer.py代码解读
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
vision_transformer.py
代码中定义的变量的含义如下:
img_size:tuple类型,里面是int类型,代表输入的图片大小,默认是224。
patch_size:tuple类型,里面是int类型,代表Patch的大小,默认是16。
in_chans:int类型,代表输入图片的channel数,默认是3。
num_classes:int类型classification head的分类数,比如CIFAR100就是100,默认是1000。
embed_dim:int类型Transformer的embedding dimension,默认是768。
depth:int类型,Transformer的Block的数量,默认是12。
num_heads:int类型,attention heads的数量,默认是12。
mlp_ratio:int类型,mlp hidden dim/embedding dim的值,默认是4。
qkv_bias:bool类型,attention模块计算qkv时需要bias吗,默认是True。
qk_scale:一般设置成None就行。
drop_rate:float类型,dropout rate,默认是0。
attn_drop_rate:float类型,attention模块的dropout rate,默认是0。
drop_path_rate:float类型,默认是0。
hybrid_backbone:nn.Module类型,在把图片转换成Patch之前,需要先通过一个Backbone吗?默认是None。
如果是None,就直接把图片转化成Patch。
如果不是None,就先通过这个Backbone,再转化成Patch。
norm_layer:nn.Module类型,归一化层类型,默认是None。
1 导入必要的库和模型:
1 | |
2 定义字典,代表标准的模型
如果需要更改模型超参数只需要改变_cfg的传入的参数即可。
1 | |
3 default_cfgs
default_cfgs代表支持的所有模型,也定义成字典的形式:
vit_small_patch16_224里面的small代表小模型。
ViT的第一步要把图片分成一个个patch,然后把这些patch组合在一起作为对图像的序列化操作,比如一张224 × 224的图片分成大小为16 × 16的patch,那一共可以分成196个。所以这个图片就序列化成了(196, 256)的tensor。所以这里的:
16:就代表patch的大小。
224:就代表输入图片的大小。
按照这个命名方式,支持的模型有:vit_base_patch16_224,vit_base_patch16_384等等。后面的vit_deit_base_patch16_224等等模型代表DeiT这篇论文的模型。
1 | |
4 FFN实现
1 | |
5 Attention实现
在python 3.5以后,@是一个操作符,表示矩阵-向量乘法
A@x 就是矩阵-向量乘法A*x: http://np.dot/(A, x)。
1 | |
6 包含Attention和Add & Norm的Block实现

不同之处是:
先进行Norm,再Attention;先进行Norm,再通过FFN (MLP)。
1 | |
7 图片转换成Patch
一种做法是直接把Image转化成Patch,另一种做法是把Backbone输出的特征转化成Patch。
1) 直接把Image转化成Patch:
输入的x的维度是:(B, C, H, W)
输出的PatchEmbedding的维度是:(B, 14*14, 768),768表示embed_dim,14*14表示一共有196个Patches。
1 | |
2) 把Backbone输出的特征转化成Patch:
输入的x的维度是:(B, C, H, W)
得到Backbone输出的维度是:(B, feature_size, feature_size, feature_dim)
输出的PatchEmbedding的维度是:(B, feature_size, feature_size, embed_dim),一共有feature_size * feature_size个Patches。
1 | |
8 VisionTransformer 类的实现
8.1 传入的变量
1 | |
8.2 获取Patch的数量
1 | |
8.3 class token:
一开始定义成(1, 1, 768),之后再变成(B, 1, 768)。
1 | |
8.4 定义位置编码
1 | |
8.6 表示层和分类头
表示层输出维度是representation_size,分类头输出维度是num_classes。
1 | |
8.7 初始化各个模块
函数trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.)的目的是用截断的正态分布绘制的值填充输入张量,我们只需要输入均值mean,标准差std,下界a,上界b即可。
self.apply(self._init_weights)表示对各个模块的权重进行初始化。apply函数的代码是:
1 | |
递归地将fn应用于每个子模块,相当于在递归调用fn,即_init_weights这个函数。
也就是把模型的所有子模块的nn.Linear和nn.LayerNorm层都初始化掉。
1 | |
8.8 forward实现
1 | |
9 下面是Training data-efficient image transformers & distillation through attention这篇论文的DeiT这个类的实现:
整体结构与ViT相似,继承了上面的VisionTransformer类。
1 | |
再额外定义以下3个变量:
- distillation token:dist_token
- 新的位置编码:pos_embed
- 蒸馏分类头:head_dist
DeiT相关介绍可以参考:https://zhuanlan.zhihu.com/p/348593638。
1 | |
初始化新定义的变量:
1 | |
前向函数:
1 | |
10 对位置编码进行插值:
posemb代表未插值的位置编码权值,posemb_tok为位置编码的token部分,posemb_grid为位置编码的插值部分。
首先把要插值部分posemb_grid给reshape成(1, gs_old, gs_old, -1)的形式,再插值成(1, gs_new, gs_new, -1)的形式,最后与token部分在第1维度拼接在一起,得到插值后的位置编码posemb。
1 | |
11 _create_vision_transformer函数用于创建vision transformer:
checkpoint_filter_fn的作用是加载预训练权重。
1 | |
12 定义和注册vision transformer模型:
@ register_model这个函数来自timm库model文件夹下的registry.py文件,它的作用是:
@ 指装饰器
@register_model代表注册器,注册这个新定义的模型。
存储到**_model_entrypoints**这个字典中,比如:
1 | |
然后在factory.py的create_model函数中的下面这几行真正创建模型,你以后想创建的任何模型都会使用create_model这个函数,这里说清楚了为什么要用它:
1 | |
比如刚才在main.py里面用了create_model创建模型,如下面代码所示。而create_model就来自factory.py:
1 | |
一共可以选择的模型包括:
ViT系列:
vit_small_patch16_224
vit_base_patch16_224
vit_base_patch32_224
vit_base_patch16_384
vit_base_patch32_384
vit_large_patch16_224
vit_large_patch32_224
vit_large_patch16_384
vit_large_patch32_384
vit_base_patch16_224_in21k
vit_base_patch32_224_in21k
vit_large_patch16_224_in21k
vit_large_patch32_224_in21k
vit_huge_patch14_224_in21k
vit_base_resnet50_224_in21k
vit_base_resnet50_384
vit_small_resnet26d_224
vit_small_resnet50d_s3_224
vit_base_resnet26d_224
vit_base_resnet50d_224DeiT系列:
vit_deit_tiny_patch16_224
vit_deit_small_patch16_224
vit_deit_base_patch16_224
vit_deit_base_patch16_384
vit_deit_tiny_distilled_patch16_224
vit_deit_small_distilled_patch16_224
vit_deit_base_distilled_patch16_224
vit_deit_base_distilled_patch16_384
以上就是对timm库 vision_transformer.py代码的分析。
4 如何使用timm库以及 vision_transformer.py代码搭建自己的模型?
在搭建我们自己的视觉Transformer模型时,我们可以按照下面的步骤操作:首先
- 继承timm库的VisionTransformer这个类。
- 添加上自己模型独有的一些变量。
- 重写forward函数。
- 通过timm库的注册器注册新模型。
我们以ViT模型的改进版DeiT为例:
首先,DeiT的所有模型列表如下:
1 | |
导入VisionTransformer这个类,注册器register_model,以及初始化函数trunc_normal_:
1 | |
DeiT的class名称是DistilledVisionTransformer,它直接继承了VisionTransformer这个类:
1 | |
添加上自己模型独有的一些变量:
1 | |
重写forward函数:
1 | |
通过timm库的注册器注册新模型:
1 | |
———————————————— 更新 2021.03.01———————————————
5 timm库 train.py代码解读:
timm库的训练使用 结合apex支持的分布式训练,同步bn,以及混合精度的训练方式,其train.py的写法很具有代表性,值得拿出来讨论。因此这篇文章再多加一段,来专门讨论这个train.py。
结合apex支持的分布式训练,同步bn,以及混合精度的训练方式的详细讲解可以参考下面这篇文章:
科技猛兽:PyTorch 77.结合apex支持的分布式训练,同步bn,以及混合精度在这篇文章中我们使用8步法结合apex支持的分布式训练,同步bn,以及混合精度:
1. 先罗列自己网络的参数:
1 | |
local_rank指定了输出设备,默认为GPU可用列表中的第一个GPU。这里这个是必须加的。原因后面讲
2. 在主函数中开头写:
1 | |
3. 导入数据接口,这里有一点不一样。需要用一个DistributedSampler:
1 | |
4. 之后定义模型:
1 | |
5. 定义优化器,损失函数,定义优化器一定要在把模型搬运到GPU之后:
1 | |
6. 多GPU设置:
1 | |
7. 记得loss要这么用:
1 | |
8. 然后在代码底部加入:
1 | |
那么这个train.py大体上依然遵循这8步:
https://github.com/rwightman/pytorch-image-models/blob/master/train.py1. 通过命令行解析定义各种超参数,包括:
Dataset / Model parameters,比如:data,–model,–pretrained等等。
Optimizer parameters,比如:–opt,–opt-eps,–momentum等等。
Learning rate schedule parameters,比如:–sched,–lr,–epochs,–start-epoch,–decay-epochs,–decay-rate等等。
Augmentation & regularization parameters,比如:–mixup,–hflip,–vflip,–cutmix,–drop等等。
Batch norm parameters,比如:–bn-tf,–bn-momentum,–sync-bn,–dist-bn,–split-bn等等。
Model Exponential Moving Average parameters,比如:–model-ema,–model-ema-force-cpu,–model-ema-decay等等。
Misc parameters,比如:–seed,–log-interval,–num-gpu,–save-images,amp,–apex-amp,–native-amp,–output,–local_rank等等。
1 | |
2. 分布式命令:
1 | |
3. 导入数据接口,这里有一点不一样。需要用一个DistributedSampler:
1 | |
4. 之后定义模型:
1 | |
5. 定义优化器,损失函数,定义优化器一定要在把模型搬运到GPU之后:
1 | |
6. 多GPU设置:
1 | |
7. 记得loss要这么用:
1 | |
8. 然后在代码底部加入:
1 | |
总结
本文简要介绍了优秀的PyTorch Image Model 库:timm库以及其中的 vision transformer 代码和训练代码。 Transformer 架构早已在自然语言处理任务中得到广泛应用,但在计算机视觉领域中仍然受到限制。在计算机视觉领域,目前已有大量工作表明模型对 CNN 的依赖不是必需的,当直接应用于图像块序列时,transformer 也能很好地执行图像分类任务。本文的目的是为学者介绍一个优秀的 vision transformer 的PyTorch实现,以便更快地开展相关实验。



