WGAN

假设 $F$ 的形式为: $$F = \left[ \begin{array}{c} f_1\ f_2 \ \vdots \ f_n \ \hline g_1 \ g_2 \ \vdots \ g_n \ \end{array} \right] = \left[ \begin{array}{c} f(x_1) \ f(x_2) \ \vdots \ f(x_n) \ \hline g(y_1) \ g(y_2) \ \vdots \ g(y_n) \ \end{array} \right]\tag{2.1}$$
为什么这样设?第一个等号说明, $F$ 的形状跟 $b$ 一样,因此每个位置 $x_1,x_2,…,x_n,y_1,y_2,…y_n$ 上对应一个待求解的未知变量。第二个等号说明我们可以用已知的 $f_i$ 和 $x_i$ 或 $g_i$ 和 $y_i$ 来得到一个统一的 $f$ 和 $g$ ,这样有利于接下来的推导分析。

$b^T F$ 就可以展开为: $$b^T F=\sum_{i,j} p(x_i)f(x_i)+q(y_i)g(y_i)\tag{2.2}$$
因此写成连续形式则有: $$b^T F=\int [p(x)f(x) +q(x)g(x)] dx\tag{2.3}$$

此外对于约束 $A^\top F \leq C$ : $$\underbrace{\left[ \begin{array}{cccc|cccc} 1 & 0 & … & 0 & 1 & 0 & … & 0\ 1 & 0 & … & 0 & 0 & 1 & … & 0\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \hline 0 & 1 & … & 0 & 1 & 0 & … & 0\ 0 & 1 & … & 0 & 0 & 1 & … & 0\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \hline \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \hline 0 & 0 & … & 1 & 1 & 0 & … & 0\ 0 & 0 & … & 1 & 0 & 1 & … & 0\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \end{array} \right]}{A^T} \underbrace{\left[ \begin{array}{c} f(x_1) \ f(x_2) \ \vdots \ f(x_n) \ \hline g(y_1) \ g(y_2) \ \vdots \ g(y_n) \ \end{array} \right]}{F} \leq \underbrace{\begin{bmatrix} c(x_1, y_1) \ c(x_1, y_2) \ \vdots \ \hline c(x_2, y_1) \ c(x_2, y_2) \ \vdots \ \hline \vdots \ \hline c(x_n, y_1) \ c(x_n, y_2) \ \vdots\ \end{bmatrix}}_C \tag{2.4}$$
展开后,实际上说的就是: $$f(x_i)+g(y_j)\leq c(x_i,y_j)\tag{2.5}$$

那么(1.16)式就转化为了: $$\max_{f,g} \left{\int [p(x)f(x) +q(x)g(x)] dx \mid f(x) + g(y) \leq c(x,y)\right}\tag{2.6}$$
这也是(1.10)W距离对应的对偶问题,即求解两个分布的W距离相当于求解(2.4)。

对于特殊情况 $x=y$ ,此时 $c(x,x)=0$ ,则: $$f(x)+g(x)\leq 0 \Rightarrow g(x)\leq -f(x)\tag{2.7}$$
这启发我们,如果令 $g(x)= -f(x)$ ,(2.6)式就能写成: $$\max_{f} \left{\int [p(x)f(x) -q(x)f(x)] dx \mid f(x) - f(y) \leq c(x,y)\right}\tag{2.8}$$
凑巧的是,(2.6)刚好等价于(2.8),下面是证明

$g(x)= -f(x)$ 是(2.7)中 $g(x)\leq -f(x)$ 的特殊情况,因此容易知道(2.8)的解一定在(2.6)中,记为 $(2.6)\supseteq(2.8)$ 。
另外,由于 $p(x)f(x) +q(x)g(x)\leq p(x)f(x) -q(x)f(x)$ ,因此(2.6)的目标一定被(2.8)包含,所以只需证明(2.8)和(2.6)的约束是等价的,那么就有 $(2.6)\subseteq(2.8)$ ,综上就能推出 $(2.6)=(2.8)$ 了。但对于约束,由于 $ f(x) + g(y) \leq f(x) - f(y)$ ,所以 $f(x) - f(y)$ 不一定会小于 $c(x,y)$ ,因此关于这部分需要换个思路。
假设(2.6)的最优解为 $f^(x)$ 和 $g^(y)$ ,(2.8)的目标此时就为 $p(x)f^(x) -q(x)f^(x)$ ,那么可以想办法构造出一个新解 $f^{}(x)$ ,其不仅满足目标,且符合约束 $f^{}(x) - f^{}(y) \leq c(x,y)$ 。由于 $f^(x) +g^(y)\leq c(x,y)$ ,则 $f^(x) \leq c(x,y)-g^(y) $ ,令: $$\begin{align} f^{}(x)&=\min_{y} {c(x,y)-g^(y) }\ z(x) &= \operatorname{argmin}_{y} {c(x,y)-g^(y) } \end{align}\tag{2.9}$$
显然 $f^
(x) \leq f^{}(x)$ ,那么(2.6)的目标: $$p(x)f^(x) +q(x)g^(x)\leq p(x)f^{}(x) +q(x)g^(x)\tag{2.10}$$
因为为 $f^
(x)$ 和 $g^(y)$ 已经是最优解了,所以(2.10)只能取等号, $f^{**}(x)$ 和 $g^(y)$ 也是一对最优解,那么约束: $$\begin{align} f^{}(x) - f^{}(y) &= [c(x, z(x)) - g^{}(z(x))] - [c(y, z(y)) - g^{}(z(y))] \ &\leq [c(x, z(y)) - g^{}(z(y))] - [c(y, z(y)) - g^{}(z(y))] \ &= c(x, z(y)) - c(y, z(y)) \ &\leq c(x, y) \end{align}\tag{2.11}$$
其中,第一个等号根据(2.9)的定义;第二个不等号是因为 $z(x)$ 已是最优解,换成其它肯定会放大;最后一个不等号是由于距离的三角不等式,而W距离中的 $c(x, y)$ 定义为欧式距离,这显然满足。
上述说明,(2.6)的解一定包含在(2.8)中(目标符合,约束也符合),因此有 $(2.6)\subseteq(2.8)$ ,综合起来就有 $(2.6)=(2.8)$ ,完成了证明

最终我们把W距离转化为了(2.8)式。对于约束 $f(x) - f(y) \leq c(x,y)$ ,根据欧氏距离的对称性还有 $f(x) - f(y) \geq -c(x,y)$ 写成期望形式: $$\max_{f, |f(x) - f(y)| \leq |x-y|} \left{\mathbb{E}{p(x)}[f(x)] -\mathbb{E}{q(x)}[f(x)] \right}\tag{2.12}$$
对于生成模型,我们会要求两个分布的距离尽量靠近,所以如果再最小化 $p(x)$ 和 $q(x)$ 的W距离,假设 $g$ 为生成器, $q(x)=q(g(z)),z\sim N(0,1)$ ,那么最终的形式为: $$\min_{g}\max_{f, |f(x) - f(y)| \leq |x-y|} \left{\mathbb{E}{p(x)}[f(x)] -\mathbb{E}{q(g(z))}[f(g(z))] \right}\tag{2.13}$$
这正是WGAN,因此WGAN的损失函数就是在缩短两个分布的W距离。

扩散模型

根据之前的推导,WGAN主要是在优化W距离,而扩散模型主要是优化KL距离(散度),那么这两个距离之间会有关联吗?2022年的一篇文章《https://arxiv.org/abs/2212.06359》就介绍了扩散模型的得分匹配损失(实质上也是KL距离的上界#ref_4)是W距离的一个上界,因此在某种程度上,优化得分就等于优化W距离,这样就将扩散模型和WGAN联系到一起了。

这篇文章介绍的最核心的定理如下:

定理1. 假设 $p_t(x)$ 服从以下正向SDE演化过程: $$dx=f(x,t)dt+g(t)dw,\qquad t\in[0,T]\tag{2.14}$$
从 $t=0$ 开始,定义 $p_0(x)$ 为数据分布。令 $ s_{\theta}(t, x)$ 是由(2.14)经过得分匹配损失训练得到的。假设 $q_t(x)$ 服从以下逆向SDE演化过程: $$dx=[f(x,t)-g(t)^2s_\theta(x,t)]dt+g(t)dw,\qquad t\in[0,T]\tag{2.15}$$
从 $t=T$ 开始,定义 $q_T(x)$ 为指定先验分布(例如标准高斯噪声)。那么有以下关系: $$W_2(p_0, q_0) \leq \int_{0}^{T} g(t)I(t)\mathbb{E}{p_t} \left[ | \nabla \log p_t(x) - s{\theta}(t, x) |^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)\tag{2.16}$$
$W_2(p_0, q_0)$ 表示 $p_0(x)$ 和 $q_0(x)$ 之间的W-2距离; $I(t) = \exp \left( \int_{0}^{t} \left( L_f(r) + L_s(r) g(r)^2 \right) dr \right)$ 单调递增,其中两个非负函数 $L_f(t)$ 和 $L_s(t)$ 来自论文的前提假设, $f(x,t)$ 满足Lipschitz约束, $s_\theta(x,t)$ 满足单边Lipschitz约束: $$\begin{align} |f(x,t)-f(y,t)|&\leq L_f(t)|x-y|\ (s_\theta(x,t)-s_\theta(y,t))\cdot (x-y)&\leq L_s(t)|x-y|^2 \end{align}\tag{2.17}$$
定理1告诉我们,如果我们优化得分匹配损失,那么也相当于优化W距离,所以扩散模型不但在优化两个分布之间的KL距离,还在悄悄优化W距离,也就揭示扩散模型模型和WGAN的联系了。

要完整这个定理需要用到很多最优传输中的引理,作者在论文中也只是简单引用,因此这里就简单介绍一下作者主要的证明思路

根据#ref_5的定理8.4.7和#ref_6的推论5.25,有: $$-\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}=\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\frac{dy}{dt}-\frac{dx}{dt}) \right]\tag{2.18}$$
这个式子是证明的**核心,**其中 $\pi_t(x,y)$ 表示 $p_t(x)$ 到 $q_t(y)$ 的最优传输策略, $dx/dt$ 和 $dy/dt$ 分别是 $p_t(x)$ 和 $q_t(y)$ 对应路径 $x$ 和 $y$ 对 $t$ 的全微分,即为概率流ODE。
(2.14)对应的概率流ODE为: $$\frac{dx}{dt}=f(x,t)-g(t)^2\nabla
{x} \log p_t(x)\tag{2.19}$$

(2.15)对应的概率流ODE为: $$\frac{dy}{dt}=f(y,t) - g(t)^2 s_\theta(y,t) + \frac{1}{2}g(t)^2 \nabla_{y} \log q_t(y)\tag{2.20}$$
带入(2.18),有: $$\begin{align} -\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}&=\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(f(y,t)-f(x,t)) \right]\ &\quad+g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(s_\theta(x,t)-s_\theta(y,t)) \right]\ &\quad+g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{x}p_t(x)-s_\theta(x,t)) \right]\ &\quad+\frac{g(t)^2}{2}\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{y}q_t(y)-\log \nabla_{x}p_t(x)) \right]\ \end{align}\tag{2.21}$$
右边第一项和第二项根据(2.17)的约束,可以很容易得到: $$\begin{align} \mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(f(y,t)-f(x,t)) \right]&\leq L_f(t)\mathbb{E}{\pi_t(x,y)}[|x-y|^2]\ &=L_f(t)W_2^2(p_t(x),q_t(y)) \end{align}\tag{2.22}$$
和: $$\begin{align} g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(s\theta(x,t)-s_\theta(y,t)) \right]&\leq g(t)^2L_s(t)\mathbb{E}{\pi_t(x,y)}[|x-y|^2]\ &=g(t)^2L_s(t)W_2^2(p_t(x),q_t(y)) \end{align}\tag{2.23}$$
第三项利用积分Cauchy-Schwarz不等式: $$\begin{align} g(t)^2\mathbb{E}
{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla_{x}p_t(x)-s_\theta(x,t)) \right] &\leq g(t)^2 \mathbb{E}{\pi_t(x,y)}[|x-y|^2]^{\frac{1}{2}}\mathbb{E}{\pi_t(x,y)}[|\log \nabla_{x}p_t(x)-s_\theta(x,t)|^2]^{\frac{1}{2}}\ &=g(t)^2W_2(p_t(x),q_t(y))\mathbb{E}{p_t(x)}[|\log \nabla{x}p_t(x)-s_\theta(x,t)|^2]^{\frac{1}{2}} \end{align}\tag{2.24}$$
第四项根据论文附录的引理2,有: $$\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{y}q_t(y)-\log \nabla_{x}p_t(x)) \right]\leq 0\tag{2.25}$$
综合(2.22)-(2.25),最终有: $$\begin{align} -\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt} & \leq L_f(t)W_2^2(p_t(x),q_t(y)) \ &\quad+ g(t)^2L_s(t)W_2^2(p_t(x),q_t(y))\ &\quad+g(t)^2W_2(p_t(x),q_t(y))b_t^{\frac{1}{2}}\ \end{align}\tag{2.26}$$
简记 $b_t=\mathbb{E}{p_t(x)}[|\log \nabla{x}p_t(x)-s_\theta(x,t)|^2]$ ,由于: $$-\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}=-W_2(p_t(x),q_t(y))\frac{dW_2(p_t(x),q_t(y))}{dt}\tag{2.27}$$
因此(2.26)两端可以整理为: $$-\frac{dW_2(p_t(x),q_t(y))}{dt}\leq (L_f(t)+g(t)^2L_s(t))W_2(p_t(x),q_t(y)) +g(t)^2b_t^{\frac{1}{2}} \tag{2.28}$$
这是非齐次线性一阶微分方程,利用常数变易法,设: $$W_2(p_t(x),q_t(y))=C_t\exp\left(\int_{t}^0 L_f(r)+g(r)^2L_s(r)dr\right)=C_t/I(t)\tag{2.29}$$
带入(2.28)有: $$-\frac{dC_t}{dt}\leq \exp\left(\int_0^t L_f(r)+g(r)^2L_s(r)dr\right) g(t)^2b_t^{\frac{1}{2}}=I(t)g(t)^2b_t^{\frac{1}{2}}\tag{2.30}$$
两边同时对 $t$ 从0积到 $T$ ,则: $$C_0\leq \int_0^T I(t)g(t)^2b_t^{\frac{1}{2}}dt+C_T\tag{2.31}$$
由于 $W_2(p_T(x),q_T(y))=C_T/I(T)$ ,则 $C_T=I(T)W_2(p_T(x),q_T(y))$ ,最终: $$W_2(p_0(x),q_0(y))=C_0\leq \int_0^T I(t)g(t)^2b_t^{\frac{1}{2}}dt+I(T)W_2(p_T(x),q_T(y))\tag{2.32}$$
完成了证明

注:苏老师在博客《https://spaces.ac.cn/archives/9467》中对于ODE情况给出了自己的证明,和原论文的主要的差别在于(2.18)式的推导。苏老师把期望的 $t$ 时刻最优传输方案 $\pi_t(x,y)$ 改为了由 $p_T(z)$ 通过 $dx/dt$ 和 $dy/dt$ ODE映射得到的 $\gamma_t(x(z),y(z))$ ,这样的好处是可以把W-2距离转化成关于中间变量 $z$ 的期望,与时间无关,那么就可以很容易的对 $t$ 求导得到类(2.18)式,后续的操作是一样的。但在SDE情况下不考虑最优传输方案会出现误差,无法使用类似的思路进行推导。

其实这种距离对 $t$ 求导数的操作在#ref_7#ref_8#ref_4中都有用过,当时是利用KL散度对 $t$ 的导数来推出它的一个上界为得分匹配,现在看来跟本文的结果异曲同工了,只是W距离的推导需要最优传输的背景,相对更加复杂。