summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMinhui Huang <32522773+mhhuang95@users.noreply.github.com>2021-09-06 08:06:50 -0700
committerGitHub <noreply@github.com>2021-09-06 17:06:50 +0200
commit96bf1a46e74d6985419e14222afb0b9241a7bb36 (patch)
tree6d2b89760a5e3568a79df5c96bc30439c2e82297
parentc105dcb892de87ae9c6cfcfc5d9c0b14f2933082 (diff)
[MRG] Projection Robust Wasserstein (#267)
* ot.dr: PRW code; text.text_dr: PRW test code. * ot.dr: PRW code; test.test_dr: PRW test code. * fix errors: pep8(3.8) * fix errors: pep8(3.8) * modified readme; prw code review * fix pep error * edit comment * modified math comment Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--README.md3
-rw-r--r--ot/dr.py114
-rw-r--r--test/test_dr.py37
3 files changed, 154 insertions, 0 deletions
diff --git a/README.md b/README.md
index 20e0606..6a2cf15 100644
--- a/README.md
+++ b/README.md
@@ -198,6 +198,7 @@ The contributors to this library are
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -283,3 +284,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
diff --git a/ot/dr.py b/ot/dr.py
index b7a1af0..64588cf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -10,6 +10,7 @@ Dimension reduction with OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -198,3 +199,116 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
+
+
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
    r"""
    Projection Robust Wasserstein Distance [32]

    The function solves the following optimization problem:

    .. math::
        \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg \cdot H(\pi)

    - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
    - :math:`H(\pi)` is entropy regularizer
    - :math:`x_i`, :math:`y_j` are samples of measures :math:`\mu` and :math:`\nu` respectively

    Every iteration performs one Sinkhorn scaling update of the coupling for
    the current projected cost, followed by one Riemannian gradient step on
    :math:`U` that is retracted with a QR decomposition.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples from measure :math:`\mu`
    Y : ndarray, shape (m, d)
        Samples from measure :math:`\nu`
    a : ndarray, shape (n, )
        weights for measure :math:`\mu`
    b : ndarray, shape (m, )
        weights for measure :math:`\nu`
    tau : float
        stepsize for Riemannian Gradient Descent
    U0 : ndarray, shape (d, k), optional
        Initial starting point for projection. When omitted, a random
        Gaussian (d, k) matrix is orthonormalized by QR and used instead.
        A user-supplied `U0` is used as given (it is not re-orthonormalized
        and its second dimension is not checked against `k`).
    reg : float, optional
        Regularization term >0 (entropic regularization)
    k : int
        Subspace dimension; must satisfy k < d
    stopThr : float, optional
        Stop threshold on error (>0)
    maxiter : int, optional
        Maximum number of iterations (>= 1)
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    pi : ndarray, shape (n, m)
        Optimal transportation matrix for the given parameters
    U : ndarray, shape (d, k)
        Projection operator.

    Raises
    ------
    ValueError
        If `X` and `Y` do not have the same number of columns, if
        ``d <= k``, or if ``maxiter < 1``.

    References
    ----------
    .. [32] Huang, M. , Ma S. & Lai L. (2021).
            A Riemannian Block Coordinate Descent Method for Computing
            the Projection Robust Wasserstein Distance, ICML.
    """  # noqa

    # initialization
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    n, d = X.shape
    m, d_y = Y.shape

    # Explicit exceptions instead of `assert`, which is stripped under -O.
    if d_y != d:
        raise ValueError(
            "X and Y must have the same number of columns, got %d and %d" % (d, d_y))
    if d <= k:
        raise ValueError(
            "subspace dimension k=%d must be smaller than the data dimension d=%d" % (k, d))
    if maxiter < 1:
        # With no iteration there is no coupling to return.
        raise ValueError("maxiter must be at least 1, got %r" % (maxiter,))

    if U0 is None:
        U, _ = np.linalg.qr(np.random.randn(d, k))
    else:
        U = np.asarray(U0, dtype=np.float64)

    # Left Sinkhorn scaling; the right scaling is recomputed from it before
    # first use, so it needs no initial value.
    u = np.full(n, 1.0 / n)

    for it in range(maxiter):

        # Projected cost matrix M_ij = ||U^T (x_i - y_j)||^2, expanded as
        # ||U^T x_i||^2 + ||U^T y_j||^2 - 2 <U^T x_i, U^T y_j> and built by
        # broadcasting rather than dense diagonal matrices.
        XU = X.dot(U)
        YU = Y.dot(U)
        M = (XU ** 2).sum(axis=1)[:, None] + (YU ** 2).sum(axis=1)[None, :] - 2 * XU.dot(YU.T)

        # Gibbs kernel
        K = np.exp(-M / reg)

        # Sinkhorn update (one pass: right scaling, then left scaling)
        v = b / K.T.dot(u)
        u = a / K.dot(v)
        pi = u[:, None] * K * v[None, :]

        # Second order matrix of the displacements:
        # sum_ij { pi_ij (X_i - Y_j)(X_i - Y_j)^T }.
        # The row marginal of pi equals `a` after the `u` update above, so
        # `a` is used for the X term and the actual column sums for Y.
        col_marginal = pi.sum(axis=0)
        cross = X.T.dot(pi).dot(Y)
        V = (X.T * a).dot(X) + (Y.T * col_marginal).dot(Y) - cross - cross.T

        # Riemannian gradient step
        G = 2 / reg * V.dot(U)
        GTU = G.T.dot(U)
        xi = G - U.dot(GTU + GTU.T) / 2  # Riemannian gradient
        U, _ = np.linalg.qr(U + tau * xi)  # Retraction by QR decomposition

        grad_norm = np.linalg.norm(xi)
        err = max(reg * grad_norm, np.linalg.norm(col_marginal - b, 1))

        if verbose:
            # Objective value is only needed for the log line.
            f_val = np.trace(U.T.dot(V.dot(U)))
            print('RBCD Iteration: ', it, ' error', err, '\t fval: ', f_val)

        # Written as `not err > stopThr` so that a NaN error also stops the
        # loop, exactly like a `while err > stopThr` condition would.
        if not err > stopThr:
            break

    return pi, U
diff --git a/test/test_dr.py b/test/test_dr.py
index c5df287..fa75a18 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -1,6 +1,7 @@
"""Tests for module dr on Dimensionality Reduction """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -57,3 +58,39 @@ def test_wda():
projwda(xs)
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_prw():
    """Run the PRW solver with and without a user-supplied starting point.

    Both runs are checked against invariants of the solver output: the
    returned projection has orthonormal columns and the coupling has the
    expected shape and row marginal.
    """
    # Fixed seed so that a failure can be reproduced.
    np.random.seed(42)

    d = 100  # Dimension
    n = 100  # Number samples
    k = 3  # Subspace dimension
    dim = 3

    def fragmented_hypercube(n, d, dim):
        """Return uniform weights and two (n, d) sample sets.

        X is uniform on [-1, 1]^d; Y is a second uniform draw whose first
        `dim` coordinates are pushed away from the origin by 2.
        """
        assert dim <= d
        assert dim >= 1
        assert dim == int(dim)

        a = (1. / n) * np.ones(n)
        b = (1. / n) * np.ones(n)

        # First measure : uniform on the hypercube
        X = np.random.uniform(-1, 1, size=(n, d))

        # Second measure : fragmentation
        tmp_y = np.random.uniform(-1, 1, size=(n, d))
        Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0])
        return a, b, X, Y

    def check_solution(pi, U):
        # U comes out of a QR retraction, so its columns are orthonormal.
        assert U.shape == (d, k)
        np.testing.assert_allclose(U.T.dot(U), np.eye(k), atol=1e-10)
        # The last Sinkhorn scaling fixes the row marginal of the coupling.
        assert pi.shape == (n, n)
        np.testing.assert_allclose(pi.sum(1), a)

    a, b, X, Y = fragmented_hypercube(n, d, dim)

    tau = 0.002
    reg = 0.2

    # Random initial projection
    pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1)
    check_solution(pi, U)

    # User-supplied initial projection
    U0 = np.random.randn(d, k)
    U0, _ = np.linalg.qr(U0)

    pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1)
    check_solution(pi, U)