返回博客列表

Untitled

推土机距离问题(Earth Mover's Distance)

假设地面上有 m 个土堆,第 i 个土堆有 r_i 数量的土;同时另一边有 n 个坑,第 j 个坑可以容纳 c_j 数量的土。假设所有的土能刚好被所有的坑填满,那么就有关系:

\sum_i{r_i}=\sum_j{c_j}\tag{1.1}

我们现在要把土从 m 搬到 n传输方案就可以用一个 m\times n 的矩阵 Y=[\gamma_{i,j}]{m\times n} 表示,其中 \gamma{i,j} 表示从第 i 个土堆搬到第 j 个坑的土的数量。

矩阵的第 i 行表示第 i 个土推搬到 n 个坑的分配方案,由于假设刚好能填满,所以有关系:

r_i=\sum_j{\gamma_{i,j}}\tag{1.2}

而矩阵的第 j 列表示第 j 个坑收到 m 个土堆的接收方案,由于假设刚好能被填满,所以有关系:
c_i=\sum_i{\gamma_{i,j}}\tag{1.3}

因此传输方案 Y 满足(1.2)和(1.3)的约束。

假设从 ij 搬运一单位的土需要花费 c_{i,j} ,那么所有的单位花费也能用一个 m\times n 的矩阵 C=[c_{i,j}]{m\times n} 来表示。那么传输方案 Y 的总花费可以表示为 W=Y\odot C=\sum{i,j}\gamma_{i,j}c_{i,j} ,最优传输的目的就是找到花费最小的传输方案,因此可以写成:

W[\gamma^*]=\inf_{\gamma} \sum_{i,j}\gamma_{i,j}c_{i,j}\tag{1.4}

如果把 mn 进行归一化,也就是 \sum_i{r_i}=\sum_j{c_j}=1 ,那么 mn 可以看成两个离散概率分布,搬土的过程就是把分布 p_{m}(i)=r_i 搬到 q_n(j)=c_j ,最优传输就是研究以最小的代价将分布 p_{m} 转化为 q_n

Wasserstein距离

把上述问题推广到连续的情况,考虑归一化后的情况,那么(1.1)可以改写为:

\int p(x)dx =\int q(x)dx=1\tag{1.5}

(1.2)式为:
p(x)=\int \gamma(x,y)dy\tag{1.6}

(1.3)式为:
p(y)=\int \gamma(x,y)dx\tag{1.7}

在这种情况下,传输方案 \gamma(x,y) 被赋予了联合概率密度的意义,它保证了“所有的坑能被所有的土填满”这一约束。最终(1.4)式可以写为:
W[\gamma^]=\inf_{\gamma \in \prod [p,q] } \int \gamma(x,y)c(x,y)dxdy\tag{1.8}

由于 \gamma(x,y) 可以看成联合概率密度(概率测度),因此有的文献还记为:
W[\gamma^]=\inf_{\gamma \in \prod [p,q] } \int c(x,y)d\gamma(x,y)\tag{1.8*}

因此,最优传输的目的就是在所有可能的联合概率密度里面,找到花费最小的,使之能够以最小的代价把分布从 p(x) 转移到 q(x)

如果我们取花费函数为欧式距离时 c(x,y)=|x-y|^\rho ,记 W^\rho=(W[\gamma^])^{1/\rho} ,则:

W^\rho=\left[\inf_{\gamma \in \prod [p,q] } \int \gamma(x,y)|x-y|^\rho dxdy\right]^{1/\rho}\tag{1.9}

称为
Wasserstein- \rho 距离*,若 \rho=1 则:
W^1=\inf_{\gamma \in \prod [p,q] } \int \gamma(x,y)|x-y| dxdy\tag{1.10}

称为Wasserstein距离或者W距离,后面的讨论主要也是以W距离为主。

如果花费为正,那么W距离恒大于等于0;并且当 p(x)=q(x) 时, \gamma(x,y)=0,x\ne y ,而 c(x,y)=0,x=y ,所以W距离此时为0。这说明(1.10)式可以看成一种衡量两个分布 p(x)q(x) 的“距离”,虽然这种“距离”不满足距离空间的定义。

对偶问题

(1.10)是一个带有下确界的优化函数,有时候不好操作,但我们可以利用线性规划的对偶性将其进行转化一下。首先我们可以把 \gamma(x,y)c(x,y) 拆成无穷长的列向量:

\Gamma = \begin{bmatrix} \gamma(x_1, y_1) \ \gamma(x_1, y_2) \ \vdots \ \hline \gamma(x_2, y_1) \ \gamma(x_2, y_2) \ \vdots \ \hline \vdots \ \hline \gamma(x_n, y_1) \ \gamma(x_n, y_2) \ \vdots\ \end{bmatrix},\qquad C = \begin{bmatrix} c(x_1, y_1) \ c(x_1, y_2) \ \vdots \ \hline c(x_2, y_1) \ c(x_2, y_2) \ \vdots \ \hline \vdots \ \hline c(x_n, y_1) \ c(x_n, y_2) \ \vdots\ \end{bmatrix} \tag{1.11}

那么(1.10)就能写成列向量的内积形式:
W^1=\inf_{\Gamma } \Gamma \cdot C\tag{1.12}

对应的约束(1.6)和(1.7)就可以写为:
\underbrace{\left[ \begin{array}{ccc|ccc|c|ccc} 1 & 1 & \cdots & 0 & 0 &\cdots & \cdots & 0 & 0 & \cdots \ 0 & 0 & \cdots & 1 & 1 &\cdots & \cdots & 0 & 0 & \cdots \ \vdots & \vdots & \ddots & \vdots & \vdots & \ddots & \cdots & \vdots & \vdots & \ddots \ 0 & 0 & \cdots & 0 & 0 &\cdots & \cdots & 1 & 1 & \cdots \ \hline 1 & 0 & \cdots & 1 & 0 &\cdots & \cdots & 1 & 0 & \cdots \ 0 & 1 & \cdots & 0 & 1 &\cdots & \cdots & 0 & 1 & \cdots \ \vdots & \vdots & \ddots & \vdots & \vdots & \ddots & \cdots & \vdots & \vdots & \ddots \ 0 & 0 & \cdots & 0 & 0 &\cdots & \cdots & 0 & 0 & \cdots \ \end{array} \right]}{A} \underbrace{ \begin{bmatrix} \gamma(x_1, y_1) \ \gamma(x_1, y_2) \ \vdots \ \hline \gamma(x_2, y_1) \ \gamma(x_2, y_2) \ \vdots \ \hline \vdots \ \hline \gamma(x_n, y_1) \ \gamma(x_n, y_2) \ \vdots\ \end{bmatrix}}{\Gamma} = \underbrace{\left[ \begin{array}{c} p(x_1) \ p(x_2) \ \vdots \ p(x_n) \ \hline q(y_1) \ q(y_2) \ \vdots \ q(y_n) \ \end{array} \right]}{b} \tag{1.13}

那么最终(1.10)就可以写成线性规划的形式:
\min{\Gamma} \left{ \Gamma\cdot C \mid A\Gamma = b, \Gamma \geq 0 \right}\tag{1.14}

根据线性规划的(强)对偶性(可以参考《https://zhuanlan.zhihu.com/p/681165783》):
\min_{x} {c^T x \mid Ax = b, x \geq 0}=\max_{y} {b^T y \mid A^T y \leq c} \tag{1.15}

(1.14)的对偶问题为:
\max_{F} {b^T F \mid A^T F \leq C}\tag{1.16}

其中, F 是跟 b 形状一致的列向量。

为什么最优传输

那么我们为什么要计算 optimal transport 呢?如果为了 measure 概率分布之间的距离,有很多现成的 measurements 可以用, 比如非常简单的 KL 散度:


KL(p,q)=\sum_{i}p_i⋅log⁡\frac{p_i}{q_i}

除了 “无法处理两个分布的支撑集不相交的情况”以及“不满足对称性” 等原因之外,一个重要的原因就是这种 逐点计算的度量 没有考虑分布内的结构信息。 所谓的结构信息,就是分布内的联系。例如在 KL 散度中,p_i,i=0,1,... 彼此之间都是独立计算最后加起来的, 而大部分情况下它们并不是独立的。

就以我们常见的分类任务为例,分类任务通常用交叉熵损失来度量模型预测和样本标签之间的距离, 交叉熵损失实际上就是在计算 onehot 化的标签和模型预测之间的 KL 散度。 这种逐点计算的损失函数(不论是交叉熵还是 L2)都无法考虑分布内不同事件的相关性。 例如将“汽车”误分类成“卡车”显然没有把“汽车”误分类成“斑马”严重。 但是用 KL 散度来度量的话,这两种错误的损失是一样的。

概率分布内的结构信息可以通过最优传输的距离矩阵 M “嵌入” 到距离度量中。 还是以分类为例,我们可以让 m_{汽车,卡车} 远小于 m_{汽车,斑马}

评论