From 78b44af2434f494c8f9e4c8c91003fbc0e1d4415 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 22 Oct 2020 09:28:53 +0100 Subject: [MRG] Sliced wasserstein (#203) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes * Implements sliced wasserstein * Changed formatting of string for py3.5 support * Docstest, expected 0.0 and not 0. * Adressed comments by @rflamary * No 3d plot here * add sliced to the docs * Incorporate comments by @rflamary * add link to pdf Co-authored-by: Rémi Flamary --- README.md | 4 + docs/source/all.rst | 1 + examples/sliced-wasserstein/README.txt | 4 + examples/sliced-wasserstein/plot_variance.py | 84 ++++++++++++++++ ot/__init__.py | 3 +- ot/sliced.py | 144 +++++++++++++++++++++++++++ test/test_sliced.py | 85 ++++++++++++++++ 7 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 examples/sliced-wasserstein/README.txt create mode 100644 examples/sliced-wasserstein/plot_variance.py create mode 100644 ot/sliced.py create mode 100644 test/test_sliced.py diff --git a/README.md b/README.md index e3598f1..6fe528a 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). 
+* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. POT provides the following Machine Learning related solvers: @@ -180,6 +181,7 @@ The contributors to this library are * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) +* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + +[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 diff --git a/docs/source/all.rst b/docs/source/all.rst index d7b878f..f1f7075 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -27,6 +27,7 @@ API and modules stochastic unbalanced partial + sliced .. 
autosummary:: :toctree: ../modules/generated/ diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt new file mode 100644 index 0000000..a575345 --- /dev/null +++ b/examples/sliced-wasserstein/README.txt @@ -0,0 +1,4 @@ + + +Sliced Wasserstein Distance +--------------------------- \ No newline at end of file diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py new file mode 100644 index 0000000..f3deeff --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" +============================== +2D Sliced Wasserstein Distance +============================== + +This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + 
+################################################################################### +# Compute Sliced Wasserstein distance for different seeds and number of projections +# ----------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +################################################################################### +# Plot Sliced Wasserstein Distance +# ----------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label="SWD") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 0e6e2e2..ec3ede2 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -39,6 +39,7 @@ from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 from .da import sinkhorn_lpl1_mm +from .sliced import sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -50,4 +51,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets' 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance'] diff --git a/ot/sliced.py b/ot/sliced.py new file mode 100644 index 0000000..4792576 --- /dev/null +++ b/ot/sliced.py @@ -0,0 +1,144 @@ +""" +Sliced Wasserstein 
Distance. + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + + +import numpy as np + + +def get_random_projections(n_projections, d, seed=None): + r""" + Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` + + Parameters + ---------- + n_projections : int + number of samples requested + d : int + dimension of the space + seed: int or RandomState, optional + Seed used for numpy random number generator + + Returns + ------- + out: ndarray, shape (n_projections, d) + The uniform unit vectors on the sphere + + Examples + -------- + >>> n_projections = 100 + >>> d = 5 + >>> projs = get_random_projections(n_projections, d) + >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + True + + """ + + if not isinstance(seed, np.random.RandomState): + random_state = np.random.RandomState(seed) + else: + random_state = seed + + projections = random_state.normal(0., 1., [n_projections, d]) + norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) + projections = projections / norm + return projections + + +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + + .. 
 math:: + \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforward of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for numpy random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." 
Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + """ + from .lp import emd2_1d + + X_s = np.asanyarray(X_s) + X_t = np.asanyarray(X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = np.full(n, 1 / n) + if b is None: + b = np.full(m, 1 / m) + + d = X_s.shape[1] + + projections = get_random_projections(n_projections, d, seed) + + X_s_projections = np.dot(projections, X_s.T) + X_t_projections = np.dot(projections, X_t.T) + + if log: + projected_emd = np.empty(n_projections) + else: + projected_emd = None + + res = 0. + + for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): + emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) + if projected_emd is not None: + projected_emd[i] = emd + res += emd + + res = (res / n_projections) ** 0.5 + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/test/test_sliced.py b/test/test_sliced.py new file mode 100644 index 0000000..a07d975 --- /dev/null +++ b/test/test_sliced.py @@ -0,0 +1,85 @@ +"""Tests for module sliced""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.sliced import get_random_projections + + +def test_get_random_projections(): + rng = np.random.RandomState(0) + projections = get_random_projections(1000, 50, rng) + np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + + +def test_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) 
+ + +def test_sliced_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + + +def test_sliced_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert len(projections) == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 1) + a = rng.uniform(0, 1, n) + a /= a.sum() + y = rng.randn(m, 1) + u = ot.utils.unif(m) + res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) + expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) + np.testing.assert_almost_equal(res ** 2, expected) -- cgit v1.2.3