summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorMinhui Huang <32522773+mhhuang95@users.noreply.github.com>2021-09-06 08:06:50 -0700
committerGitHub <noreply@github.com>2021-09-06 17:06:50 +0200
commit96bf1a46e74d6985419e14222afb0b9241a7bb36 (patch)
tree6d2b89760a5e3568a79df5c96bc30439c2e82297 /ot/dr.py
parentc105dcb892de87ae9c6cfcfc5d9c0b14f2933082 (diff)
[MRG] Projection Robust Wasserstein (#267)
* ot.dr: PRW code; text.text_dr: PRW test code. * ot.dr: PRW code; test.test_dr: PRW test code. * fix errors: pep8(3.8) * fix errors: pep8(3.8) * modified readme; prw code review * fix pep error * edit comment * modified math comment Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/ot/dr.py b/ot/dr.py
index b7a1af0..64588cf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -10,6 +10,7 @@ Dimension reduction with OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -198,3 +199,116 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
+
+
def _second_order_displacement(X, Y, row_weights, pi):
    # Second-order matrix of the displacements,
    # sum_ij pi_ij (x_i - y_j)(x_i - y_j)^T, expanded into four products so
    # that no (n, n) or (m, m) diagonal matrix is ever materialised.
    # `row_weights` stands in for the row marginal of `pi`.
    cross = X.T.dot(pi).dot(Y)
    return ((X.T * row_weights).dot(X)
            + (Y.T * np.sum(pi, 0)).dot(Y)
            - cross - cross.T)


def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
    r"""
    Projection Robust Wasserstein Distance [32]

    The function solves the following optimization problem:

    .. math::
        \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}
        \|U^T(x_i - y_j)\|^2 - reg \cdot H(\pi)

    - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
    - :math:`H(\pi)` is entropy regularizer
    - :math:`x_i`, :math:`y_j` are samples of measures :math:`\mu` and
      :math:`\nu` respectively

    Each iteration performs one Sinkhorn update of the scaling vectors
    (warm-started from the previous iteration) followed by one Riemannian
    gradient step on `U` with a QR retraction.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples from measure :math:`\mu`
    Y : ndarray, shape (m, d)
        Samples from measure :math:`\nu`
    a : ndarray, shape (n, )
        weights for measure :math:`\mu`
    b : ndarray, shape (m, )
        weights for measure :math:`\nu`
    tau : float
        stepsize for the Riemannian gradient step
    U0 : ndarray, shape (d, k), optional
        Initial starting point for projection. When None, a random
        orthonormal (d, k) matrix is drawn (QR of a Gaussian matrix).
        When given, it is used as is and its column count determines the
        subspace dimension actually optimised.
    reg : float, optional
        Regularization term >0 (entropic regularization)
    k : int
        Subspace dimension, must satisfy k < d
    stopThr : float, optional
        Stop threshold on error (>0)
    maxiter : int, optional
        Maximum number of iterations
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    pi : ndarray, shape (n, m)
        Optimal transportation matrix for the given parameters. If no
        iteration is performed (``maxiter <= 0`` or ``stopThr >= 1``) the
        independent coupling ``outer(a, b)`` is returned.
    U : ndarray, shape (d, k)
        Projection operator.

    Raises
    ------
    ValueError
        If `X` and `Y` do not share the same feature dimension, if
        ``d <= k``, if ``reg <= 0``, or if `a`, `b` or `U0` have shapes
        inconsistent with `X` and `Y`.

    References
    ----------
    .. [32] Huang, M. , Ma S. & Lai L. (2021).
            A Riemannian Block Coordinate Descent Method for Computing
            the Projection Robust Wasserstein Distance, ICML.
    """  # noqa

    # initialization and input validation
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    n, d = X.shape
    m, d_y = Y.shape
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    if d_y != d:
        raise ValueError(
            "X and Y must have the same number of features, got %d and %d" % (d, d_y))
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not d > k:
        raise ValueError(
            "subspace dimension k=%d must be smaller than the data dimension d=%d" % (k, d))
    if reg <= 0:
        raise ValueError("reg must be strictly positive, got %r" % (reg,))
    if a.shape != (n,) or b.shape != (m,):
        raise ValueError(
            "a and b must have shapes (%d,) and (%d,), got %r and %r"
            % (n, m, a.shape, b.shape))

    if U0 is None:
        U, _ = np.linalg.qr(np.random.randn(d, k))
    else:
        U = np.asarray(U0)
        if U.ndim != 2 or U.shape[0] != d:
            raise ValueError(
                "U0 must be a 2D array with %d rows, got shape %r" % (d, U.shape))

    # Sinkhorn row scaling; it is carried over from one outer iteration to
    # the next, so each iteration only needs a single Sinkhorn update.
    u = np.ones(n) / n
    # Feasible fallback so that `pi` is defined even if the loop never runs.
    pi = np.outer(a, b)

    err = 1
    it = 0

    while err > stopThr and it < maxiter:

        # Projected cost matrix M_ij = ||U^T (x_i - y_j)||^2, computed from
        # the projected samples instead of forming U U^T and dense diagonals.
        XU = X.dot(U)
        YU = Y.dot(U)
        M = (np.sum(XU ** 2, 1).reshape(-1, 1)
             + np.sum(YU ** 2, 1).reshape(1, -1)
             - 2 * XU.dot(YU.T))

        # Gibbs kernel exp(-M / reg), built in place
        A = np.empty(M.shape, dtype=M.dtype)
        np.divide(M, -reg, out=A)
        np.exp(A, out=A)

        # Sinkhorn update
        v = np.divide(b, np.dot(A.T, u))
        u = a / np.dot(A, v)
        pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))

        # After the u-update the rows of pi sum to `a`, which is why `a` is
        # passed as the row marginal.
        V = _second_order_displacement(X, Y, a, pi)

        # Riemannian gradient step on the Stiefel manifold
        G = 2 / reg * V.dot(U)
        GTU = G.T.dot(U)
        xi = G - U.dot(GTU + GTU.T) / 2  # Riemannian gradient
        U, _ = np.linalg.qr(U + tau * xi)  # Retraction by QR decomposition

        # Stopping criterion: scaled gradient norm and column-marginal violation
        grad_norm = np.linalg.norm(xi)
        err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))

        if verbose:
            # The objective value is only needed for reporting.
            f_val = np.trace(U.T.dot(V.dot(U)))
            print('RBCD Iteration: ', it, ' error', err, '\t fval: ', f_val)

        it = it + 1

    return pi, U