# Recovered from the patch "[MRG] Projection Robust Wasserstein (#267)"
# (author: Minhui Huang), which adds this function to ot/dr.py.


def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
    r"""
    Projection Robust Wasserstein Distance [32]

    The function solves the following optimization problem:

    .. math::
        \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}
        \|U^T(x_i - y_j)\|^2 - reg \cdot H(\pi)

    - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
    - :math:`H(\pi)` is entropy regularizer
    - :math:`x_i`, :math:`y_j` are samples of measures :math:`\mu` and
      :math:`\nu` respectively

    The solver alternates one Sinkhorn scaling update of the transport plan
    with one Riemannian gradient step on the projection (QR retraction).

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples from measure :math:`\mu`
    Y : ndarray, shape (m, d)
        Samples from measure :math:`\nu`
    a : ndarray, shape (n, )
        weights for measure :math:`\mu`
    b : ndarray, shape (m, )
        weights for measure :math:`\nu`
    tau : float
        stepsize for Riemannian Gradient Descent
    U0 : ndarray, shape (d, k), optional
        Initial starting point for projection. When None, a random Gaussian
        matrix is drawn and orthonormalized by QR decomposition.
    reg : float, optional
        Regularization term >0 (entropic regularization)
    k : int, optional
        Subspace dimension (must satisfy k < d)
    stopThr : float, optional
        Stop threshold on error (>0)
    maxiter : int, optional
        Maximum number of iterations
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    pi : ndarray, shape (n, m)
        Optimal transportation matrix for the given parameters. If no
        iteration is performed (``maxiter <= 0`` or ``stopThr >= 1``) the
        independent coupling :math:`a b^T` is returned.
    U : ndarray, shape (d, k)
        Projection operator.

    Raises
    ------
    ValueError
        If `X` and `Y` do not have the same number of features.
    AssertionError
        If the subspace dimension `k` is not strictly smaller than `d`.

    References
    ----------
    .. [32] Huang, M. , Ma S. & Lai L. (2021).
        A Riemannian Block Coordinate Descent Method for Computing
        the Projection Robust Wasserstein Distance, ICML.
    """  # noqa

    # initialization
    n, d = X.shape
    m, d_y = Y.shape
    if d_y != d:
        raise ValueError(
            "X and Y must have the same number of features, got %d and %d"
            % (d, d_y))
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    u = np.ones(n) / n  # initial left Sinkhorn scaling

    # Kept as an assertion so the exception type seen by callers is unchanged.
    assert d > k, "subspace dimension k must be strictly smaller than d"

    if U0 is None:
        U = np.random.randn(d, k)
        U, _ = np.linalg.qr(U)
    else:
        U = U0

    # Loop invariants.
    inv_a = (1 / a).reshape(-1, 1)
    XtaX = (X.T * a).dot(X)  # X^T diag(a) X without forming diag(a)

    # Fallback plan so that `pi` is defined even if the loop never runs.
    pi = a.reshape((-1, 1)) * b.reshape((1, -1))

    err = 1
    it = 0

    while err > stopThr and it < maxiter:

        # Projected cost matrix M_ij = ||U^T (x_i - y_j)||^2, expanded as
        # ||U^T x_i||^2 + ||U^T y_j||^2 - 2 <U^T x_i, U^T y_j>.
        XU = X.dot(U)
        YU = Y.dot(U)
        M = (np.sum(XU ** 2, axis=1).reshape((-1, 1))
             + np.sum(YU ** 2, axis=1).reshape((1, -1))
             - 2 * XU.dot(YU.T))

        # Gibbs kernel exp(-M / reg), computed in place.
        A = np.empty(M.shape, dtype=M.dtype)
        np.divide(M, -reg, out=A)
        np.exp(A, out=A)

        # Sinkhorn update (one sweep: v first, then u, so the row
        # marginals of pi match `a` exactly).
        v = np.divide(b, np.dot(A.T, u))
        u = 1. / np.dot(inv_a * A, v)
        pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))

        # Second order matrix of the displacements:
        # sum_ij pi_ij (x_i - y_j)(x_i - y_j)^T. The row marginals equal
        # `a` after the u-update; the column marginals are read from pi.
        cross = X.T.dot(pi).dot(Y)
        V = XtaX + (Y.T * np.sum(pi, 0)).dot(Y) - cross - cross.T

        # Riemannian gradient descent
        G = 2 / reg * V.dot(U)
        GTU = G.T.dot(U)
        xi = G - U.dot(GTU + GTU.T) / 2  # Riemannian gradient
        U, _ = np.linalg.qr(U + tau * xi)  # Retraction by QR decomposition

        grad_norm = np.linalg.norm(xi)
        err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))

        if verbose:
            f_val = np.trace(U.T.dot(V.dot(U)))
            print('RBCD Iteration: ', it, ' error', err, '\t fval: ', f_val)

        it = it + 1

    return pi, U