nn.Module
state_dict
在PyTorch中,
state_dict是一个字典对象,用于存储模型或优化器的参数。这个字典将每一层或优化器的参数映射到对应的张量。state_dict的主要作用在于方便模型的保存和加载,以便在训练过程中恢复模型的状态或在其他任务中重用模型。
对于模型(如
torch.nn.Module的实例),state_dict包含模型的可学习参数(如权重和偏置)。只有包含可学习参数的层(如卷积层、线性层等)和已注册的缓冲区(如Batch Normalization层的运行均值和方差)才会在state_dict中有对应的条目。这些参数是在模型训练过程中被优化器更新的。对于优化器(如
torch.optim的实例),state_dict包含优化器的状态信息以及使用的超参数(如学习率、动量等)。这些状态信息用于在训练过程中更新模型的参数。
通过调用torch.save(model.state_dict(), PATH)可以将模型的state_dict保存到磁盘上,其中PATH是保存的路径。同样地,通过model.load_state_dict(torch.load(PATH))可以加载之前保存的state_dict到模型中。
1 | |
1 | |
load_state_dict
load_state_dict是 PyTorch 中torch.nn.Module类的一个方法,用于加载模型的状态字典(state dictionary)。状态字典是一个包含模型所有参数的字典,通常通过调用state_dict()方法获得。
参数说明
load_state_dict 的主要参数是一个字典,该字典包含了要加载的参数。通常,这个字典是通过 torch.load() 从一个文件(通常是 .pth 或 .pt 文件)中加载的。
1 | |
注意事项
- 模型结构匹配:在调用
load_state_dict之前,确保你已经定义了与保存状态字典时完全相同的模型结构。如果模型结构不匹配,你将无法加载状态字典,因为 PyTorch 无法将参数映射到正确的位置。 - 设备:加载的状态字典中的参数默认在 CPU 上。如果你想在 GPU 上使用这些参数,你需要先将模型移动到 GPU 上,然后再加载状态字典。
1 | |
或者,你也可以在加载状态字典后移动模型:
1 | |
- 优化器状态:除了模型参数外,你可能还想加载优化器的状态。这可以通过类似的方式完成,但请注意,优化器的状态字典应该单独加载。
1 | |
model.train()
model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

当调用 model.train() 时,模型将处于训练模式,这通常意味着:
Dropout:如果模型中包含 Dropout 层,那么在训练模式下,Dropout 层会在前向传播时随机地将一部分神经元的输出设置为零。这有助于防止模型过拟合。在评估模式下,Dropout 层不会进行任何操作,所有神经元的输出都会被保留。
BatchNorm:Batch Normalization(BatchNorm)层在训练和评估模式下的行为也有所不同。在训练模式下,BatchNorm 会计算每个批次的均值和方差,并使用这些统计量来标准化输入。同时,它还会更新其内部运行均值和方差的估计值。在评估模式下,BatchNorm 会使用这些运行均值和方差来进行标准化,而不是每个批次的统计量。
其他层:有些自定义的层或模块也可能在训练和评估模式下有不同的行为。这取决于这些层或模块的实现。
使用示例
在训练循环的开始,你通常会调用 model.train() 来确保模型处于正确的模式:
1 | |
在评估或测试模型时,你应该使用 model.eval() 来确保模型不会应用 Dropout 或使用批次的统计量进行 BatchNorm:
1 | |
请注意,在评估模式下,使用 with torch.no_grad(): 块是一个好习惯,因为它可以防止计算不必要的梯度,从而节省计算资源和内存。
model.eval()

model.eval() 是 PyTorch 框架中一个非常关键的方法,用于将模型设置为评估模式(evaluation mode)。在模型训练完成后,我们通常会对模型进行评估或测试以检查其性能。这时候,调用 model.eval() 是非常必要的,因为它会影响模型中的某些层(如 Dropout 和 Batch Normalization)的行为。
主要作用
Dropout:在训练模式下,Dropout 层会随机丢弃一部分神经元的输出,这有助于防止模型过拟合。但在评估模式下,
model.eval()会关闭 Dropout 层的功能,确保所有神经元都参与前向传播,从而得到更稳定、更准确的输出。Batch Normalization:BatchNorm 层在训练和评估模式下的行为也不同。在训练模式下,BatchNorm 会使用当前批次的均值和方差来标准化输入。而在评估模式下,
model.eval()会指示 BatchNorm 使用训练过程中积累的运行均值(running mean)和运行方差(running variance)来进行标准化。这样做的好处是,模型在评估时对每个输入的标准化方式是一致的,不受批次大小的影响。
使用方法
在 PyTorch 中,使用 model.eval() 很简单。你只需在模型评估或测试之前调用它即可:
1 | |
在上面的代码中,我们首先加载了预训练的模型参数,然后调用 model.eval() 将模型设置为评估模式。注意,我们还使用了 with torch.no_grad(): 块来确保在评估过程中不计算梯度,这有助于节省内存和计算资源。
注意事项
确保在正确的位置调用:确保在评估或测试开始前调用
model.eval(),并在训练开始前调用model.train()。不要在训练循环内部多次调用model.eval(),除非你有特定的需求。梯度计算:调用
model.eval()后,模型中的所有可学习参数的requires_grad属性将被设置为False,这意味着在评估模式下不会计算梯度。这有助于加速推理过程。BatchNorm 和 Dropout 的固定:如前所述,
model.eval()会固定 BatchNorm 层和关闭 Dropout 层,确保在评估时模型的行为是一致的。
model.eval()和torch.no_grad()的区别
在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:
主要用于通知dropout层和BN层在training和validation/test模式间切换:
- 在train模式下,dropout网络层会按照设定的参数p,设置保留激活单元的概率(保留概率=p)。BN层会继续计算数据的mean和var等参数并更新。
- 在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
eval模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。
而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。
如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。



