summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2020-10-22 09:28:53 +0100
committerGitHub <noreply@github.com>2020-10-22 10:28:53 +0200
commit78b44af2434f494c8f9e4c8c91003fbc0e1d4415 (patch)
tree013002f0a65918cee5eb95648965d4361f0c3dc2
parent7adc1b1aa73c55dc07983ff08dcb23fd71e9e8b6 (diff)
[MRG] Sliced wasserstein (#203)
* example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes * Implements sliced wasserstein * Changed formatting of string for py3.5 support * Docstest, expected 0.0 and not 0. * Adressed comments by @rflamary * No 3d plot here * add sliced to the docs * Incorporate comments by @rflamary * add link to pdf Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--README.md4
-rw-r--r--docs/source/all.rst1
-rw-r--r--examples/sliced-wasserstein/README.txt4
-rw-r--r--examples/sliced-wasserstein/plot_variance.py84
-rw-r--r--ot/__init__.py3
-rw-r--r--ot/sliced.py144
-rw-r--r--test/test_sliced.py85
7 files changed, 324 insertions, 1 deletions
diff --git a/README.md b/README.md
index e3598f1..6fe528a 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
+* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31].
POT provides the following Machine Learning related solvers:
@@ -180,6 +181,7 @@ The contributors to this library are
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
+* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+
+[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
diff --git a/docs/source/all.rst b/docs/source/all.rst
index d7b878f..f1f7075 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -27,6 +27,7 @@ API and modules
stochastic
unbalanced
partial
+ sliced
.. autosummary::
:toctree: ../modules/generated/
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
new file mode 100644
index 0000000..a575345
--- /dev/null
+++ b/examples/sliced-wasserstein/README.txt
@@ -0,0 +1,4 @@
+
+
+Sliced Wasserstein Distance
+--------------------------- \ No newline at end of file
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
new file mode 100644
index 0000000..f3deeff
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+2D Sliced Wasserstein Distance
+==============================
+
+This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500  # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+###################################################################################
+# Compute Sliced Wasserstein distance for different seeds and number of projections
+# ---------------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, len(n_projections_arr)))  # one row per seed, one column per projection count
+
+# %% Compute statistics
+for seed in range(n_seed):
+    for i, n_projections in enumerate(n_projections_arr):
+        res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)
+
+res_mean = np.mean(res, axis=0)  # statistics are taken across seeds
+res_std = np.std(res, axis=0)
+
+###################################################################################
+# Plot Sliced Wasserstein Distance
+# --------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label="SWD")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Sliced Wasserstein Distance with 95% confidence interval')
+
+pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index 0e6e2e2..ec3ede2 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -39,6 +39,7 @@ from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
+from .sliced import sliced_wasserstein_distance
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -50,4 +51,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets'
'emd_1d', 'emd2_1d', 'wasserstein_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2']
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance']
diff --git a/ot/sliced.py b/ot/sliced.py
new file mode 100644
index 0000000..4792576
--- /dev/null
+++ b/ot/sliced.py
@@ -0,0 +1,144 @@
+"""
+Sliced Wasserstein Distance.
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#
+# License: MIT License
+
+
+import numpy as np
+
+
+def get_random_projections(n_projections, d, seed=None):
+    r"""
+    Draw unit vectors uniformly from the sphere :math:`\mathcal{S}^{d-1}`.
+
+    Each row is a standard Gaussian sample of dimension d rescaled to unit
+    Euclidean norm, i.e. a draw from :math:`\mathcal{U}(\mathcal{S}^{d-1})`.
+
+    Parameters
+    ----------
+    n_projections : int
+        number of unit vectors to draw
+    d : int
+        dimension of the ambient space
+    seed: int or RandomState, optional
+        An existing RandomState is used as is (and advanced by the draw); any
+        other value is passed to ``np.random.RandomState`` to build a new one
+
+    Returns
+    -------
+    out: ndarray, shape (n_projections, d)
+        The sampled directions, one unit vector per row
+
+    Examples
+    --------
+    >>> projs = get_random_projections(100, 5)
+    >>> np.allclose(np.sum(np.square(projs), 1), 1.)  # doctest: +NORMALIZE_WHITESPACE
+    True
+
+    """
+
+    if isinstance(seed, np.random.RandomState):
+        rng = seed
+    else:
+        rng = np.random.RandomState(seed)
+
+    gaussian = rng.normal(0., 1., [n_projections, d])
+    return gaussian / np.linalg.norm(gaussian, ord=2, axis=1, keepdims=True)
+
+
+def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
+    r"""
+    Monte-Carlo estimate of the 2-Sliced Wasserstein distance between two empirical measures
+
+    .. math::
+        \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}}
+
+    where :
+
+    - :math:`\theta_\# \mu` stands for the pushforward of :math:`\mu` by the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`
+
+    The expectation is replaced by an average over n_projections random directions.
+
+    Parameters
+    ----------
+    X_s : ndarray, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : ndarray, shape (n_samples_b, dim)
+        samples in the target domain
+    a : ndarray, shape (n_samples_a,), optional
+        samples weights in the source domain (uniform weights if None)
+    b : ndarray, shape (n_samples_b,), optional
+        samples weights in the target domain (uniform weights if None)
+    n_projections : int, optional
+        Number of projections used for the Monte-Carlo approximation
+    seed: int or RandomState or None, optional
+        Seed used for numpy random number generator
+    log: bool, optional
+        if True, also return the projections used and their associated EMD.
+
+    Returns
+    -------
+    cost: float
+        Sliced Wasserstein Cost
+    log : dict, optional
+        returned only if log==True; holds the keys "projections" and "projected_emds"
+
+    Raises
+    ------
+    ValueError
+        if X_s and X_t do not have the same number of columns
+
+    Examples
+    --------
+
+    >>> n_samples_a = 20
+    >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+    >>> sliced_wasserstein_distance(X, X, seed=0)  # doctest: +NORMALIZE_WHITESPACE
+    0.0
+
+    References
+    ----------
+
+    .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+    """
+    from .lp import emd2_1d
+
+    X_s = np.asanyarray(X_s)
+    X_t = np.asanyarray(X_t)
+
+    n_source, n_target = X_s.shape[0], X_t.shape[0]
+    dim_source, dim_target = X_s.shape[1], X_t.shape[1]
+
+    if dim_source != dim_target:
+        raise ValueError(
+            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(dim_source,
+                                                                                                      dim_target))
+
+    # Missing weights default to the uniform distribution over the samples.
+    if a is None:
+        a = np.full(n_source, 1 / n_source)
+    if b is None:
+        b = np.full(n_target, 1 / n_target)
+
+    directions = get_random_projections(n_projections, dim_source, seed)
+
+    # One row per direction: the samples projected onto that direction.
+    source_1d = np.dot(directions, X_s.T)
+    target_1d = np.dot(directions, X_t.T)
+
+    # 1D transport cost along every direction.
+    emds = [emd2_1d(s_proj, t_proj, a, b, log=False, dense=False)
+            for s_proj, t_proj in zip(source_1d, target_1d)]
+
+    # Accumulate left to right from 0. before averaging and taking the root.
+    swd = (sum(emds, 0.) / n_projections) ** 0.5
+    if not log:
+        return swd
+
+    projected_emd = np.empty(n_projections)
+    projected_emd[:] = emds
+    return swd, {"projections": directions, "projected_emds": projected_emd}
diff --git a/test/test_sliced.py b/test/test_sliced.py
new file mode 100644
index 0000000..a07d975
--- /dev/null
+++ b/test/test_sliced.py
@@ -0,0 +1,85 @@
+"""Tests for module sliced"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.sliced import get_random_projections
+
+
+def test_get_random_projections():
+    """Every sampled projection direction must have unit Euclidean norm."""
+    directions = get_random_projections(1000, 50, np.random.RandomState(0))
+    np.testing.assert_almost_equal(np.square(directions).sum(axis=1), 1.)
+
+
+def test_sliced_same_dist():
+    """The sliced distance between a sample and itself is zero."""
+    n_samples = 100
+    rng = np.random.RandomState(0)
+
+    points = rng.randn(n_samples, 2)
+    weights = ot.utils.unif(n_samples)
+    swd = ot.sliced_wasserstein_distance(points, points, weights, weights, n_projections=10, seed=rng)
+    np.testing.assert_almost_equal(swd, 0.)
+
+
+def test_sliced_bad_shapes():
+    """Samples living in spaces of different dimension are rejected."""
+    n_samples = 100
+    rng = np.random.RandomState(0)
+
+    x_2d = rng.randn(n_samples, 2)
+    y_4d = rng.randn(n_samples, 4)
+    weights = ot.utils.unif(n_samples)
+    with pytest.raises(ValueError):
+        ot.sliced_wasserstein_distance(x_2d, y_4d, weights, weights, n_projections=10, seed=rng)
+
+
+def test_sliced_log():
+    """With log=True the projections and the per-projection costs are returned."""
+    n_samples = 100
+    n_projections = 10
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, 4)
+    y = rng.randn(n_samples, 4)
+    weights = ot.utils.unif(n_samples)
+
+    _, log = ot.sliced_wasserstein_distance(x, y, weights, weights, n_projections, seed=rng, log=True)
+
+    assert len(log) == 2
+    assert len(log["projections"]) == n_projections
+    assert len(log["projected_emds"]) == n_projections
+    assert all(emd > 0 for emd in log["projected_emds"])
+
+
+def test_sliced_different_dists():
+    """Two independently drawn samples are at a strictly positive sliced distance."""
+    n_samples = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, 2)  # x is drawn before y, as the random stream order fixes the data
+    y = rng.randn(n_samples, 2)
+    weights = ot.utils.unif(n_samples)
+
+    assert ot.sliced_wasserstein_distance(x, y, weights, weights, n_projections=10, seed=rng) > 0.
+
+
+def test_1d_sliced_equals_emd():
+    """For one-dimensional samples the squared sliced distance must match emd2_1d."""
+    n_source, n_target = 100, 120
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_source, 1)
+    raw_weights = rng.uniform(0, 1, n_source)
+    a = raw_weights / raw_weights.sum()  # random non-uniform source weights summing to one
+    y = rng.randn(n_target, 1)
+    b = ot.utils.unif(n_target)
+
+    swd = ot.sliced_wasserstein_distance(x, y, a, b, n_projections=10, seed=42)
+    np.testing.assert_almost_equal(swd ** 2, ot.emd2_1d(x.squeeze(), y.squeeze(), a, b))