state_dict

在PyTorch中,state_dict是一个字典对象,用于存储模型或优化器的参数。这个字典将每一层或优化器的参数映射到对应的张量。state_dict的主要作用在于方便模型的保存和加载,以便在训练过程中恢复模型的状态或在其他任务中重用模型。

  • 对于模型(如torch.nn.Module的实例),state_dict包含模型的可学习参数(如权重和偏置)。只有包含可学习参数的层(如卷积层、线性层等)和已注册的缓冲区(如Batch Normalization层的运行均值和方差)才会在state_dict中有对应的条目。这些参数是在模型训练过程中被优化器更新的。

  • 对于优化器(如torch.optim的实例),state_dict包含优化器的状态信息以及使用的超参数(如学习率、动量等)。这些状态信息用于在训练过程中更新模型的参数。

通过调用torch.save(model.state_dict(), PATH)可以将模型的state_dict保存到磁盘上,其中PATH是保存的路径。同样地,通过model.load_state_dict(torch.load(PATH))可以加载之前保存的state_dict到模型中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#encoding:utf-8

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F

#define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(16*5*5,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)

def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x

def main():
# Initialize model
model = TheModelClass()

#Initialize optimizer
optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

#print model's state_dict
print('Model.state_dict:')
for param_tensor in model.state_dict():
#打印 key value字典
print(param_tensor,'\t',model.state_dict()[param_tensor].size())

#print optimizer's state_dict
print('Optimizer,s state_dict:')
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])



if __name__=='__main__':
main()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Model.state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer`s state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

load_state_dict

load_state_dict 是 PyTorch 中 torch.nn.Module 类的一个方法,用于加载模型的状态字典(state dictionary)。状态字典是一个包含模型所有参数的字典,通常通过调用 state_dict() 方法获得。

参数说明

load_state_dict 的主要参数是一个字典,该字典包含了要加载的参数。通常,这个字典是通过 torch.load() 从一个文件(通常是 .pth.pt 文件)中加载的。

1
2
state_dict = torch.load('path_to_model.pth')
model.load_state_dict(state_dict)

注意事项

  1. 模型结构匹配:在调用 load_state_dict 之前,确保你已经定义了与保存状态字典时完全相同的模型结构。如果模型结构不匹配,你将无法加载状态字典,因为 PyTorch 无法将参数映射到正确的位置。
  2. 设备:加载的状态字典中的参数默认在 CPU 上。如果你想在 GPU 上使用这些参数,你需要先将模型移动到 GPU 上,然后再加载状态字典。
1
2
model = model.to('cuda')
model.load_state_dict(torch.load('path_to_model.pth'))

或者,你也可以在加载状态字典后移动模型:

1
2
model.load_state_dict(torch.load('path_to_model.pth', map_location=torch.device('cuda')))
model = model.to('cuda')
  1. 优化器状态:除了模型参数外,你可能还想加载优化器的状态。这可以通过类似的方式完成,但请注意,优化器的状态字典应该单独加载。
1
2
optimizer_state_dict = torch.load('path_to_optimizer.pth')
optimizer.load_state_dict(optimizer_state_dict)

model.train()

model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

当调用 model.train() 时,模型将处于训练模式,这通常意味着:

  1. Dropout:如果模型中包含 Dropout 层,那么在训练模式下,Dropout 层会在前向传播时随机地将一部分神经元的输出设置为零。这有助于防止模型过拟合。在评估模式下,Dropout 层不会进行任何操作,所有神经元的输出都会被保留。

  2. BatchNorm:Batch Normalization(BatchNorm)层在训练和评估模式下的行为也有所不同。在训练模式下,BatchNorm 会计算每个批次的均值和方差,并使用这些统计量来标准化输入。同时,它还会更新其内部运行均值和方差的估计值。在评估模式下,BatchNorm 会使用这些运行均值和方差来进行标准化,而不是每个批次的统计量。

  3. 其他层:有些自定义的层或模块也可能在训练和评估模式下有不同的行为。这取决于这些层或模块的实现。

使用示例

在训练循环的开始,你通常会调用 model.train() 来确保模型处于正确的模式:

1
2
3
4
5
6
7
8
9
10
11
12
model = MyModel()  # 假设 MyModel 是你的模型类
model.train() # 将模型设置为训练模式

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
for inputs, targets in dataloader:
optimizer.zero_grad() # 清零梯度
outputs = model(inputs) # 前向传播
loss = criterion(outputs, targets) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重

在评估或测试模型时,你应该使用 model.eval() 来确保模型不会应用 Dropout 或使用批次的统计量进行 BatchNorm:

1
2
3
4
5
6
model.eval()  # 将模型设置为评估模式

with torch.no_grad(): # 不计算梯度,节省内存和计算资源
for inputs, targets in test_dataloader:
outputs = model(inputs) # 前向传播
# ... 计算性能指标等 ...

请注意,在评估模式下,使用 with torch.no_grad(): 块是一个好习惯,因为它可以防止计算不必要的梯度,从而节省计算资源和内存。

model.eval()

model.eval() 是 PyTorch 框架中一个非常关键的方法,用于将模型设置为评估模式(evaluation mode)。在模型训练完成后,我们通常会对模型进行评估或测试以检查其性能。这时候,调用 model.eval() 是非常必要的,因为它会影响模型中的某些层(如 Dropout 和 Batch Normalization)的行为。

主要作用

  1. Dropout:在训练模式下,Dropout 层会随机丢弃一部分神经元的输出,这有助于防止模型过拟合。但在评估模式下,model.eval() 会关闭 Dropout 层的功能,确保所有神经元都参与前向传播,从而得到更稳定、更准确的输出。

  2. Batch Normalization:BatchNorm 层在训练和评估模式下的行为也不同。在训练模式下,BatchNorm 会使用当前批次的均值和方差来标准化输入。而在评估模式下,model.eval() 会指示 BatchNorm 使用训练过程中积累的运行均值(running mean)和运行方差(running variance)来进行标准化。这样做的好处是,模型在评估时对每个输入的标准化方式是一致的,不受批次大小的影响。

使用方法

在 PyTorch 中,使用 model.eval() 很简单。你只需在模型评估或测试之前调用它即可:

1
2
3
4
5
6
7
8
9
model = MyModel()  # 假设 MyModel 是你的模型类
model.load_state_dict(torch.load('path_to_model.pth')) # 加载预训练模型参数

model.eval() # 将模型设置为评估模式

with torch.no_grad(): # 不计算梯度,节省内存和计算资源
for inputs, targets in test_dataloader:
outputs = model(inputs) # 进行前向传播
# ... 计算性能指标等 ...

在上面的代码中,我们首先加载了预训练的模型参数,然后调用 model.eval() 将模型设置为评估模式。注意,我们还使用了 with torch.no_grad(): 块来确保在评估过程中不计算梯度,这有助于节省内存和计算资源。

注意事项

  1. 确保在正确的位置调用:确保在评估或测试开始前调用 model.eval(),并在训练开始前调用 model.train()。不要在训练循环内部多次调用 model.eval(),除非你有特定的需求。

  2. 梯度计算:调用 model.eval() 后,模型中的所有可学习参数的 requires_grad 属性将被设置为 False,这意味着在评估模式下不会计算梯度。这有助于加速推理过程。

  3. BatchNorm 和 Dropout 的固定:如前所述,model.eval() 会固定 BatchNorm 层和关闭 Dropout 层,确保在评估时模型的行为是一致的。

model.eval()和torch.no_grad()的区别

在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

  1. 主要用于通知dropout层和BN层在training和validation/test模式间切换:

    • 在train模式下,dropout网络层会按照设定的参数p,设置保留激活单元的概率(保留概率=p)。BN层会继续计算数据的mean和var等参数并更新。
    • 在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
  2. eval模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

Reference