diff options
author | Minhui Huang <32522773+mhhuang95@users.noreply.github.com> | 2021-09-06 08:06:50 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-06 17:06:50 +0200 |
commit | 96bf1a46e74d6985419e14222afb0b9241a7bb36 (patch) | |
tree | 6d2b89760a5e3568a79df5c96bc30439c2e82297 | |
parent | c105dcb892de87ae9c6cfcfc5d9c0b14f2933082 (diff) |
[MRG] Projection Robust Wasserstein (#267)
* ot.dr: PRW code; text.text_dr: PRW test code.
* ot.dr: PRW code; test.test_dr: PRW test code.
* fix errors: pep8(3.8)
* fix errors: pep8(3.8)
* modified readme; prw code review
* fix pep error
* edit comment
* modified math comment
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | ot/dr.py | 114 | ||||
-rw-r--r-- | test/test_dr.py | 37 |
3 files changed, 154 insertions, 0 deletions
@@ -198,6 +198,7 @@ The contributors to this library are * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) +* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -283,3 +284,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). @@ -10,6 +10,7 @@ Dimension reduction with OT """ # Author: Remi Flamary <remi.flamary@unice.fr> +# Minhui Huang <mhhuang@ucdavis.edu> # # License: MIT License @@ -198,3 +199,116 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): return (X - mx.reshape((1, -1))).dot(Popt) return Popt, proj + + +def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): + r""" + Projection Robust Wasserstein Distance [32] + + The function solves the following optimization problem: + + .. math:: + \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) + + - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold + - :math:`H(\pi)` is entropy regularizer + - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively + + Parameters + ---------- + X : ndarray, shape (n, d) + Samples from measure \mu + Y : ndarray, shape (n, d) + Samples from measure \nu + a : ndarray, shape (n, ) + weights for measure \mu + b : ndarray, shape (n, ) + weights for measure \nu + tau : float + stepsize for Riemannian Gradient Descent + U0 : ndarray, shape (d, p) + Initial starting point for projection. + reg : float, optional + Regularization term >0 (entropic regularization) + k : int + Subspace dimension + stopThr : float, optional + Stop threshold on error (>0) + verbose : int, optional + Print information along iterations. + + Returns + ------- + pi : ndarray, shape (n, n) + Optimal transportation matrix for the given parameters + U : ndarray, shape (d, k) + Projection operator. + + References + ---------- + .. [32] Huang, M. , Ma S. & Lai L. (2021). + A Riemannian Block Coordinate Descent Method for Computing + the Projection Robust Wasserstein Distance, ICML. + """ # noqa + + # initialization + n, d = X.shape + m, d = Y.shape + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + u = np.ones(n) / n + v = np.ones(m) / m + ones = np.ones((n, m)) + + assert d > k + + if U0 is None: + U = np.random.randn(d, k) + U, _ = np.linalg.qr(U) + else: + U = U0 + + def Vpi(X, Y, a, b, pi): + # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }. + A = X.T.dot(pi).dot(Y) + return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T + + err = 1 + iter = 0 + + while err > stopThr and iter < maxiter: + + # Projected cost matrix + UUT = U.dot(U.T) + M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot( + np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T)) + + A = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=A) + np.exp(A, out=A) + + # Sinkhorn update + Ap = (1 / a).reshape(-1, 1) * A + AtransposeU = np.dot(A.T, u) + v = np.divide(b, AtransposeU) + u = 1. / np.dot(Ap, v) + pi = u.reshape((-1, 1)) * A * v.reshape((1, -1)) + + V = Vpi(X, Y, a, b, pi) + + # Riemannian gradient descent + G = 2 / reg * V.dot(U) + GTU = G.T.dot(U) + xi = G - U.dot(GTU + GTU.T) / 2 # Riemannian gradient + U, _ = np.linalg.qr(U + tau * xi) # Retraction by QR decomposition + + grad_norm = np.linalg.norm(xi) + err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1)) + + f_val = np.trace(U.T.dot(V.dot(U))) + if verbose: + print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val) + + iter = iter + 1 + + return pi, U diff --git a/test/test_dr.py b/test/test_dr.py index c5df287..fa75a18 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -1,6 +1,7 @@ """Tests for module dr on Dimensionality Reduction """ # Author: Remi Flamary <remi.flamary@unice.fr> +# Minhui Huang <mhhuang@ucdavis.edu> # # License: MIT License @@ -57,3 +58,39 @@ def test_wda(): projwda(xs) np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) + + +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_prw(): + d = 100 # Dimension + n = 100 # Number samples + k = 3 # Subspace dimension + dim = 3 + + def fragmented_hypercube(n, d, dim): + assert dim <= d + assert dim >= 1 + assert dim == int(dim) + + a = (1. / n) * np.ones(n) + b = (1. / n) * np.ones(n) + + # First measure : uniform on the hypercube + X = np.random.uniform(-1, 1, size=(n, d)) + + # Second measure : fragmentation + tmp_y = np.random.uniform(-1, 1, size=(n, d)) + Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0]) + return a, b, X, Y + + a, b, X, Y = fragmented_hypercube(n, d, dim) + + tau = 0.002 + reg = 0.2 + + pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1) + + U0 = np.random.randn(d, k) + U0, _ = np.linalg.qr(U0) + + pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1) |