模型容器
模型容器
Torch.nn中一个重要的概念是模型容器 (Containers),常用的容器有 3 个,这些容器都是继承自nn.Module。
- nn.Sequetial:常用于 block 构建,按照顺序包装多个网络层;
- nn.ModuleList:常用于大量重复网络构建,通过 for 循环实现重复构建,像 python 的 list 一样包装多个网络层,可以迭代;
- nn.ModuleDict:常用于可选择的网络层,像 python 的 dict一样包装多个网络层,通过 (key, value) 的方式为每个网络层指定名称;
nn.Sequential
在传统的机器学习中,有一个步骤是特征工程,我们需要从数据中人为地提取特征,然后把特征输入到分类器中预测。在深度学习的时代,特征工程的概念被弱化了,特征提取和分类器这两步被融合到了一个神经网络中。在卷积神经网络中,前面的卷积层以及池化层可以认为是特征提取部分,而后面的全连接层可以认为是分类器部分。比如 LeNet 就可以分为特征提取和分类器两部分,这 2 部分都可以分别使用 nn.Seuqtial 来包装。
代码如下:
1 | |
在初始化时,nn.Sequetial会调用__init__()方法,将每一个子 module 添加到自身_modules属性中。这里可以看到,我们传入的参数可以是一个 list,或者一个 OrderDict。如果是一个 OrderDict,那么则使用 OrderDict 里的 key,否则使用数字作为 key (OrderDict 的情况会在下面提及)。
1 | |
网络初始化完成后有两个子 module:features和classifier。
而features中的子 module 如下,每个网络层以序号作为 key:
在进行前向传播时,会进入 LeNet 的forward()函数,首先调用第一个Sequetial容器:self.features,由于self.features也是一个 module,因此会调用__call__()函数,里面调用result = self.forward(*input, **kwargs),进入nn.Seuqetial的forward()函数,在这里依次调用所有的 module。
1 | |
在上面可以看到在nn.Sequetial中,里面的每个子网络层 module 是使用序号来索引的,即使用数字来作为 key。一旦网络层增多,难以查找特定的网络层,这种情况可以使用 OrderDict (有序字典)。代码中使用
1 | |
小结
nn.Sequetial是nn.Module的容器,用于按顺序包装一组网络层,有以下两个特性。
- 顺序性:各网络层之间严格按照顺序构建,我们在构建网络时,一定要注意前后网络层之间输入和输出数据之间的形状是否匹配
- 自带
forward()函数:在nn.Sequetial的forward()函数里通过 for 循环依次读取每个网络层,执行前向传播运算。这使得我们我们构建的模型更加简洁
nn.ModuleList
nn.ModuleList是nn.Module的容器,用于包装一组网络层,以迭代的方式调用网络层。ModuleList 接收一个子模块(或层,需属于nn.Module类)的列表作为输入,然后也可以类似List那样进行append和extend操作。同时,子模块或层的权重也会自动添加到网络中来。主要有以下 3 个方法:
- append():在 ModuleList 后面添加网络层
- extend():拼接两个 ModuleList
- insert():在 ModuleList 的指定位置中插入网络层
1 | |
1 | |
下面的代码通过列表生成式来循环迭代创建 20 个全连接层,非常方便,只是在 forward()函数中需要手动调用每个网络层。
1 | |
1 | |
要特别注意的是,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起。ModuleList中元素的先后顺序并不代表其在网络中的真实位置顺序,需要经过forward函数指定各个层的先后顺序后才算完成了模型的定义。具体实现时用for循环即可完成:
1 | |
nn.ModuleDict
nn.ModuleDict是nn.Module的容器,用于包装一组网络层,以索引的方式调用网络层,ModuleDict能够更方便地为神经网络的层添加名称。主要有以下 5 个方法:
- clear():清空 ModuleDict
- items():返回可迭代的键值对 (key, value)
- keys():返回字典的所有 key
- values():返回字典的所有 value
- pop():返回一对键值,并从字典中删除
下面的模型创建了两个ModuleDict:self.choices和self.activations,在前向传播时通过传入对应的 key 来执行对应的网络层。
1 | |
1 | |
容器总结
nn.Sequetial:顺序性,各网络层之间严格按照顺序执行,常用于 block 构建,在前向传播时的代码调用变得简洁
nn.ModuleList:迭代行,常用于大量重复网络构建,通过 for 循环实现重复构建
nn.ModuleDict:索引性,常用于可选择的网络层






