LST: Ladder Side-Tuning for Parameter and Memory Efficient Transfer Learning

在简述 LST 方法之前,我们先思考一下,上述的 Adapter 方法和 P-Tuning 方法在训练微调时到底高效在哪?

我们知道在对模型训练时,我们姑且可以笼统的分为两步:反向传播和梯度下降,以 Pytorch 为例具体就是

1
2
loss.backward()  #反向传播
optimizer.step() #梯度下降
  • backward()会根据模型的前向推理计算图来反向的对各个 layer 中的 weight 求的偏导$\frac{\partial L}{\partial w}$ 于之后的梯度下降,对各个 layer 中的 input 求偏导$\frac{\partial L}{\partial x}$用于向前层传递梯度,链式求导。

  • step()则会根据所选用的优化器,对需要训练的参数执行相应的梯度下降策略,我们姑且可以将此过程简单描述成 如下公式 , 其中$\eta$ 学习率

$$
w=w-\eta\frac{\partial L}{\partial w}
$$

LST 直接在原模型的推理旁路上加了一个新的分支,也是固定原模型中的参数,将原模型各层输出与新建的旁路分支结合得到输出。

20240205212804

从图中不难发现,三种微调方法:

对于 Adapter 来说虽然只训练新插入的少部分参数,但是整个梯度回传的过程不能省略,换句话说,与微调整个模型相比:1)对于反向传播过程而言,各层对 weight 的梯度$\frac{\partial L}{\partial w}$ 用算了,但是对于 input 的梯度$\frac{\partial L}{\partial x}$ 算(要向前层传递); 2)对于梯度下降过程而言,只需要下降少部分新插入的层的 weight,原模型的 weight 都不动
对于 P-Tuning 虽然只需要训练 Embedding 层,但是 Embedding 层是输入层,所以与微调整个模型相比:1)对于反向传播过程而言,各层对 weight 的梯度$\frac{\partial L}{\partial w}$ 用算了,但是对于 input 的梯度$\frac{\partial L}{\partial x}$ 算(要向前层传递); 2)对于梯度下降过程而言,只需要下降 Embedding 层参数,原模型的 weight 都不动
所以我们再来回答本小节开始的问题:高效微调的高效 ,我觉得在于减少了所需梯度下降过程的权重量和计算量,对于反向传播的过程,不需要保存对原始 weight 得梯度也就节省了显存,但是反向传播的复杂度并没有降低。

而对于 LST 便是一种反向传播过程和梯度下降过程都高效的微调方法,如上图(c)而言,不难发现,LST 的反向传播和梯度下降过程都与原始模型无关,相当于我重新定义了一个小的模型结构,通过获取原模型的输出作为输入来协助微调最终的结果

对比

微调技术 对virtual token的处理 参与微调的参数 反向传播 梯度下降 推理延迟
Adapter 不用求原 W 梯度,得求全层的 X 梯度 下降少量新增网络的 W 会增加
P-Tuning MLP+LSTM/MLP 不用求原 W 梯度,得求全层的 X 梯度 下降 Embedding 层的 W 增加较少
Prefix tuning MLP
LST 只需要求一个小网络的 W 和 X 梯度 下降一个新增轻量级网络的 W 会增加
LoRA 既要求 W 梯度,还得求全层的 X 梯度,且计算量增多 下降少量的 W 没有