推土机距离问题(Earth Mover's Distance)¶
假设地面上有 m 个土堆,第 i 个土堆有 r_i 数量的土;同时另一边有 n 个坑,第 j 个坑可以容纳 c_j 数量的土。假设所有的土能刚好被所有的坑填满,那么就有关系:
\sum_i{r_i}=\sum_j{c_j}\tag{1.1}
我们现在要把土从
m 搬到
n ,
传输方案就可以用一个
m\times n 的矩阵
Y=[\gamma_{i,j}]{m\times n} 表示,其中
\gamma{i,j} 表示从第
i 个土堆搬到第
j 个坑的土的数量。
矩阵的第 i 行表示第 i 个土推搬到 n 个坑的分配方案,由于假设刚好能填满,所以有关系:
r_i=\sum_j{\gamma_{i,j}}\tag{1.2}
而矩阵的第
j 列表示第
j 个坑收到
m 个土堆的接收方案,由于假设刚好能被填满,所以有关系:
c_i=\sum_i{\gamma_{i,j}}\tag{1.3}
因此传输方案
Y 满足(1.2)和(1.3)的约束。
假设从 i 到 j 搬运一单位的土需要花费 c_{i,j} ,那么所有的单位花费也能用一个 m\times n 的矩阵 C=[c_{i,j}]{m\times n} 来表示。那么传输方案 Y 的总花费可以表示为 W=Y\odot C=\sum{i,j}\gamma_{i,j}c_{i,j} ,最优传输的目的就是找到花费最小的传输方案,因此可以写成:
W[\gamma^*]=\inf_{\gamma} \sum_{i,j}\gamma_{i,j}c_{i,j}\tag{1.4}
如果把
m 和
n 进行归一化,也就是
\sum_i{r_i}=\sum_j{c_j}=1 ,那么
m 和
n 可以看成两个离散概率分布,搬土的过程就是把分布
p_{m}(i)=r_i 搬到
q_n(j)=c_j ,最优传输就是研究以最小的代价将分布
p_{m} 转化为
q_n 。
Wasserstein距离¶
把上述问题推广到连续的情况,考虑归一化后的情况,那么(1.1)可以改写为:
\int p(x)dx =\int q(x)dx=1\tag{1.5}
(1.2)式为:
p(x)=\int \gamma(x,y)dy\tag{1.6}
(1.3)式为:
p(y)=\int \gamma(x,y)dx\tag{1.7}
在这种情况下,传输方案
\gamma(x,y) 被赋予了联合概率密度的意义,它保证了“所有的坑能被所有的土填满”这一约束。最终(1.4)式可以写为:
W[\gamma^]=\inf_{\gamma \in \prod [p,q] } \int \gamma(x,y)c(x,y)dxdy\tag{1.8}
由于
\gamma(x,y) 可以看成联合概率密度(概率测度),因此有的文献还记为:
W[\gamma^]=\inf_{\gamma \in \prod [p,q] } \int c(x,y)d\gamma(x,y)\tag{1.8*}
因此,最优传输的目的就是在所有可能的联合概率密度里面,找到花费最小的,使之能够以最小的代价把分布从
p(x) 转移到
q(x) 。
如果我们取花费函数为欧式距离时 c(x,y)=|x-y|^\rho ,记 W^\rho=(W[\gamma^])^{1/\rho} ,则:
W^\rho=\left[\inf_{\gamma \in \prod [p,q] } \int \gamma(x,y)|x-y|^\rho dxdy\right]^{1/\rho}\tag{1.9}
称为
Wasserstein- \rho 距离*,若
\rho=1 则:
W^1=\inf_{\gamma \in \prod [p,q] } \int \gamma(x,y)|x-y| dxdy\tag{1.10}
称为Wasserstein距离或者W距离,后面的讨论主要也是以W距离为主。
如果花费为正,那么W距离恒大于等于0;并且当 p(x)=q(x) 时, \gamma(x,y)=0,x\ne y ,而 c(x,y)=0,x=y ,所以W距离此时为0。这说明(1.10)式可以看成一种衡量两个分布 p(x) 和 q(x) 的“距离”,虽然这种“距离”不满足距离空间的定义。
对偶问题¶
(1.10)是一个带有下确界的优化函数,有时候不好操作,但我们可以利用线性规划的对偶性将其进行转化一下。首先我们可以把 \gamma(x,y) 和 c(x,y) 拆成无穷长的列向量:
\Gamma = \begin{bmatrix} \gamma(x_1, y_1) \ \gamma(x_1, y_2) \ \vdots \ \hline \gamma(x_2, y_1) \ \gamma(x_2, y_2) \ \vdots \ \hline \vdots \ \hline \gamma(x_n, y_1) \ \gamma(x_n, y_2) \ \vdots\ \end{bmatrix},\qquad C = \begin{bmatrix} c(x_1, y_1) \ c(x_1, y_2) \ \vdots \ \hline c(x_2, y_1) \ c(x_2, y_2) \ \vdots \ \hline \vdots \ \hline c(x_n, y_1) \ c(x_n, y_2) \ \vdots\ \end{bmatrix} \tag{1.11}
那么(1.10)就能写成列向量的内积形式:
W^1=\inf_{\Gamma } \Gamma \cdot C\tag{1.12}
对应的约束(1.6)和(1.7)就可以写为:
\underbrace{\left[ \begin{array}{ccc|ccc|c|ccc} 1 & 1 & \cdots & 0 & 0 &\cdots & \cdots & 0 & 0 & \cdots \ 0 & 0 & \cdots & 1 & 1 &\cdots & \cdots & 0 & 0 & \cdots \ \vdots & \vdots & \ddots & \vdots & \vdots & \ddots & \cdots & \vdots & \vdots & \ddots \ 0 & 0 & \cdots & 0 & 0 &\cdots & \cdots & 1 & 1 & \cdots \ \hline 1 & 0 & \cdots & 1 & 0 &\cdots & \cdots & 1 & 0 & \cdots \ 0 & 1 & \cdots & 0 & 1 &\cdots & \cdots & 0 & 1 & \cdots \ \vdots & \vdots & \ddots & \vdots & \vdots & \ddots & \cdots & \vdots & \vdots & \ddots \ 0 & 0 & \cdots & 0 & 0 &\cdots & \cdots & 0 & 0 & \cdots \ \end{array} \right]}{A} \underbrace{ \begin{bmatrix} \gamma(x_1, y_1) \ \gamma(x_1, y_2) \ \vdots \ \hline \gamma(x_2, y_1) \ \gamma(x_2, y_2) \ \vdots \ \hline \vdots \ \hline \gamma(x_n, y_1) \ \gamma(x_n, y_2) \ \vdots\ \end{bmatrix}}{\Gamma} = \underbrace{\left[ \begin{array}{c} p(x_1) \ p(x_2) \ \vdots \ p(x_n) \ \hline q(y_1) \ q(y_2) \ \vdots \ q(y_n) \ \end{array} \right]}{b} \tag{1.13}
那么最终(1.10)就可以写成线性规划的形式:
\min{\Gamma} \left{ \Gamma\cdot C \mid A\Gamma = b, \Gamma \geq 0 \right}\tag{1.14}
根据线性规划的(强)对偶性(可以参考《
https://zhuanlan.zhihu.com/p/681165783》):
\min_{x} {c^T x \mid Ax = b, x \geq 0}=\max_{y} {b^T y \mid A^T y \leq c} \tag{1.15}
(1.14)的对偶问题为:
\max_{F} {b^T F \mid A^T F \leq C}\tag{1.16}
其中,
F 是跟
b 形状一致的列向量。
为什么最优传输¶
那么我们为什么要计算 optimal transport 呢?如果为了 measure 概率分布之间的距离,有很多现成的 measurements 可以用, 比如非常简单的 KL 散度:
KL(p,q)=\sum_{i}p_i⋅log\frac{p_i}{q_i}
除了 “无法处理两个分布的支撑集不相交的情况”以及“不满足对称性” 等原因之外,一个重要的原因就是这种 逐点计算的度量
没有考虑分布内的结构信息。 所谓的结构信息,就是分布内的联系。例如在 KL 散度中,
p_i,i=0,1,... 彼此之间都是独立计算最后加起来的, 而大部分情况下它们并不是独立的。
就以我们常见的分类任务为例,分类任务通常用交叉熵损失来度量模型预测和样本标签之间的距离, 交叉熵损失实际上就是在计算 onehot 化的标签和模型预测之间的 KL 散度。 这种逐点计算的损失函数(不论是交叉熵还是 L2)都无法考虑分布内不同事件的相关性。 例如将“汽车”误分类成“卡车”显然没有把“汽车”误分类成“斑马”严重。 但是用 KL 散度来度量的话,这两种错误的损失是一样的。
概率分布内的结构信息可以通过最优传输的距离矩阵 M “嵌入” 到距离度量中。 还是以分类为例,我们可以让 m_{汽车,卡车} 远小于 m_{汽车,斑马} 。
评论