概述

PyTorchImageModels,简称 timm,是一个巨大的 PyTorch 代码集合,包括了一系列:

  • image models
  • layers
  • utilities
  • optimizers
  • schedulers
  • data-loaders / augmentations
  • training / validation scripts

旨在将各种 SOTA 模型整合在一起,并具有复现 ImageNet 训练结果的能力。虽然模型架构是 timm 的重点,但它还包括许多数据增强 (data augmentations)、正则化技术 (regularization techniques)、优化器 (optimizers) 和学习率策略 (learning rate schedulers) 的实现。

作者github链接:rwightman - Overview
timm库链接:rwightman/pytorch-image-models

timm库特点

所有的模型都有默认的API

  • accessing/changing the classifier - get_classifier and reset_classifier
  • 只对features做前向传播 - forward_features

所有模型都支持多尺度特征提取 (feature pyramids) (通过create_model函数)

  • create_model(name, features_only=True, out_indices=..., output_stride=...)

out_indices 指定返回哪个feature maps to return, 从0开始,out_indices[i]对应着 C(i + 1) feature level。

output_stride 通过dilated convolutions控制网络的output stride。大多数网络默认 stride 32 。

所有的模型都有一致的pretrained weight loader,adapts last linear if necessary。

训练方式支持

  • NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)

  • PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)

  • PyTorch w/ single GPU single process (AMP optional)

动态的全局池化方式可以选择:average pooling, max pooling, average + max, or concat([average, max]),默认是adaptive average。

Schedulers

Schedulers 包括step,cosinew/ restarts,tanhw/ restarts,plateau

Optimizer