summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2021-11-02 14:19:57 +0100
committerGitHub <noreply@github.com>2021-11-02 14:19:57 +0100
commit6775a527f9d3c801f8cdd805d8f205b6a75551b9 (patch)
treec0ed5a7c297b4003688fec52d46f918ea0086a7d
parenta335324d008e8982be61d7ace937815a2bfa98f9 (diff)
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
-rw-r--r--README.md11
-rw-r--r--docs/source/readme.rst51
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py185
-rw-r--r--examples/backends/plot_wass1d_torch.py152
-rw-r--r--examples/sliced-wasserstein/plot_variance.py2
-rw-r--r--ot/__init__.py5
-rw-r--r--ot/backend.py98
-rw-r--r--ot/lp/__init__.py367
-rw-r--r--ot/lp/solver_1d.py367
-rw-r--r--ot/sliced.py181
-rw-r--r--test/test_1d_solver.py85
-rw-r--r--test/test_backend.py36
-rw-r--r--test/test_ot.py57
-rw-r--r--test/test_sliced.py90
-rw-r--r--test/test_utils.py2
15 files changed, 1244 insertions, 445 deletions
diff --git a/README.md b/README.md
index f0e5227..cfb9744 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
-* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].
+* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -285,4 +285,11 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021
-[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 82d3e6c..ee32e2b 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -24,7 +24,7 @@ POT provides the following generic OT solvers (links to examples):
for regularized OT [7].
- Entropic regularization OT solver with `Sinkhorn Knopp
Algorithm <auto_examples/plot_OT_1D.html>`__
- [2] , stabilized version [9] [10], greedy Sinkhorn [22] and
+ [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and
`Screening Sinkhorn
[26] <auto_examples/plot_screenkhorn_1D.html>`__.
- Bregman projections for `Wasserstein
@@ -54,6 +54,9 @@ POT provides the following generic OT solvers (links to examples):
solver <auto_examples/plot_stochastic.html>`__
for Large-scale Optimal Transport (semi-dual problem [18] and dual
problem [19])
+- `Stochastic solver of Gromov
+ Wasserstein <auto_examples/gromov/plot_gromov.html>`__
+ for large-scale problem with any loss functions [33]
- Non regularized `free support Wasserstein
barycenters <auto_examples/barycenters/plot_free_support_barycenter.html>`__
[20].
@@ -137,19 +140,12 @@ following Python modules:
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23) (build only, not necessary when installing wheels
- from pip or conda)
+- Cython (>=0.23) (build only, not necessary when installing from pip
+ or conda)
Pip installation
^^^^^^^^^^^^^^^^
-Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
-be installed prior to installing POT. This can be done easily with
-
-.. code:: console
-
- pip install numpy cython
-
You can install the toolbox through PyPI with:
.. code:: console
@@ -183,7 +179,8 @@ without errors:
import ot
-Note that for easier access the module is name ot instead of pot.
+Note that for easier access the module is named ``ot`` instead of
+``pot``.
Dependencies
~~~~~~~~~~~~
@@ -222,7 +219,7 @@ Short examples
.. code:: python
- # a and b are 1D histograms (sum to 1 and positive)
+ # a,b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
Wd = ot.emd2(a, b, M) # exact linear program
Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
@@ -232,7 +229,7 @@ Short examples
.. code:: python
- # a and b are 1D histograms (sum to 1 and positive)
+ # a,b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
T = ot.emd(a, b, M) # exact linear program
T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
@@ -287,6 +284,10 @@ The contributors to this library are
- `Ievgen Redko <https://ievred.github.io/>`__ (Laplacian DA, JCPOT)
- `Adrien Corenflos <https://adriencorenflos.github.io/>`__ (Sliced
Wasserstein Distance)
+- `Tanguy Kerdoncuff <https://hv0nnus.github.io/>`__ (Sampled Gromov
+ Wasserstein)
+- `Minhui Huang <https://mhhuang95.github.io>`__ (Projection Robust
+ Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like
to thank the following persons for providing some code (in various
@@ -476,6 +477,30 @@ of
measures <https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf>`__,
Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate
+Descent Method for Computing the Projection Robust Wasserstein
+Distance <http://proceedings.mlr.press/v139/huang21e.html>`__,
+Proceedings of the 38th International Conference on Machine Learning
+(ICML).
+
+[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov
+Wasserstein <https://hal.archives-ouvertes.fr/hal-03232509/document>`__,
+Machine Learning Journal (MJL), 2021
+
+[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A.,
+& Peyré, G. (2019, April). `Interpolating between optimal transport and
+MMD using Sinkhorn
+divergences <http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf>`__.
+In The 22nd International Conference on Artificial Intelligence and
+Statistics (pp. 2681-2690). PMLR.
+
+[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N.,
+Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein
+distance and its use for
+gans <https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf>`__.
+In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
+Recognition (pp. 10648-10656).
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
new file mode 100644
index 0000000..05b9952
--- /dev/null
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -0,0 +1,185 @@
+r"""
+=================================
+Sliced Wasserstein barycenter and gradient flow with PyTorch
+=================================
+
+In this exemple we use the pytorch backend to optimize the sliced Wasserstein
+loss between two empirical distributions [31].
+
+In the first example one we perform a
+gradient flow on the support of a distribution that minimize the sliced
+Wassersein distance as poposed in [36].
+
+In the second exemple we optimize with a gradient descent the sliced
+Wasserstein barycenter between two distributions as in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
+(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling
+via optimal transport and diffusions. In International Conference on
+Machine Learning (pp. 4104-4113). PMLR.
+
+
+"""
+# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+# %%
+# Loading the data
+
+
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+import ot
+import matplotlib.animation as animation
+
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
+
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
+
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
+
+pl.figure(1, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+
+# %%
+# Sliced Wasserstein gradient flow with Pytorch
+# ---------------------------------------------
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True)
+x2_torch = torch.tensor(x2).to(device=device)
+
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
+
+loss_iter = []
+
+# generator for random permutations
+gen = torch.Generator()
+gen.manual_seed(42)
+
+for i in range(nb_iter_max):
+
+ loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = x1_torch.grad
+ x1_torch -= grad * lr / (1 + i / 5e1) # step
+ x1_torch.grad.zero_()
+ x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy()
+
+xb = x1_torch.clone().detach().cpu().numpy()
+
+pl.figure(2, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$')
+pl.title('Sliced Wasserstein gradient flow')
+pl.legend()
+ax = pl.axis()
+
+# %%
+# Animate trajectories of the gradient flow along iteration
+# -------------------------------------------------------
+
+pl.figure(3, (8, 4))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+ pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+ pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
+ pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i))
+ pl.axis(ax)
+ return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
+
+# %%
+# Compute the Sliced Wasserstein Barycenter
+#
+x1_torch = torch.tensor(x1).to(device=device)
+x3_torch = torch.tensor(x3).to(device=device)
+xbinit = np.random.randn(500, 2) * 10 + 16
+xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
+
+lr = 1e3
+nb_iter_max = 100
+
+x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
+
+loss_iter = []
+
+# generator for random permutations
+gen = torch.Generator()
+gen.manual_seed(42)
+
+alpha = 0.5
+
+for i in range(nb_iter_max):
+
+ loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \
+ + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen)
+
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = xbary_torch.grad
+ xbary_torch -= grad * lr # / (1 + i / 5e1) # step
+ xbary_torch.grad.zero_()
+ x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()
+
+xb = xbary_torch.clone().detach().cpu().numpy()
+
+pl.figure(4, (8, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$')
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter')
+pl.title('Sliced Wasserstein barycenter')
+pl.legend()
+ax = pl.axis()
+
+
+# %%
+# Animate trajectories of the barycenter along gradient descent
+# -------------------------------------------------------
+
+pl.figure(5, (8, 4))
+
+
+def _update_plot(i):
+ pl.clf()
+ pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$')
+ pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$')
+ pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$')
+ pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i))
+ pl.axis(ax)
+ return 1
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000)
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
new file mode 100644
index 0000000..0abdd6d
--- /dev/null
+++ b/examples/backends/plot_wass1d_torch.py
@@ -0,0 +1,152 @@
+r"""
+=================================
+Wasserstein 1D with PyTorch
+=================================
+
+In this small example, we consider the following minization problem:
+
+.. math::
+ \mu^* = \min_\mu W(\mu,\nu)
+
+where :math:`\nu` is a reference 1D measure. The problem is handled
+by a projected gradient descent method, where the gradient is computed
+by pyTorch automatic differentiation. The projection on the simplex
+ensures that the iterate will remain on the probability simplex.
+
+This example illustrates both `wasserstein_1d` function and backend use within
+the POT framework.
+"""
+# Author: Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import matplotlib as mpl
+import torch
+
+from ot.lp import wasserstein_1d
+from ot.datasets import make_1D_gauss as gauss
+from ot.utils import proj_simplex
+
+red = np.array(mpl.colors.to_rgb('red'))
+blue = np.array(mpl.colors.to_rgb('blue'))
+
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# enforce sum to one on the support
+a = a / a.sum()
+b = b / b.sum()
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device).requires_grad_(True)
+b_torch = torch.tensor(b).to(device=device)
+
+lr = 1e-6
+nb_iter_max = 800
+
+loss_iter = []
+
+pl.figure(1, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+
+for i in range(nb_iter_max):
+ # Compute the Wasserstein 1D with torch backend
+ loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2)
+ # record the corresponding loss value
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = a_torch.grad
+ a_torch -= a_torch.grad * lr # step
+ a_torch.grad.zero_()
+ a_torch.data = proj_simplex(a_torch) # projection onto the simplex
+
+ # plot one curve every 10 iterations
+ if i % 10 == 0:
+ mix = float(i) / nb_iter_max
+ pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red)
+
+pl.legend()
+pl.title('Distribution along the iterations of the projected gradient descent')
+pl.show()
+
+pl.figure(2)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
+
+# %%
+# Wasserstein barycenter
+# ---------
+# In this example, we consider the following Wasserstein barycenter problem
+# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$
+# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t`
+# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient
+# descent method, where the gradient is computed by pyTorch automatic differentiation.
+# The projection on the simplex ensures that the iterate will remain on the
+# probability simplex.
+#
+# This example illustrates both `wasserstein_1d` function and backend use within the
+# POT framework.
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# use pyTorch for our data
+x_torch = torch.tensor(x).to(device=device)
+a_torch = torch.tensor(a).to(device=device)
+b_torch = torch.tensor(b).to(device=device)
+bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True)
+
+
+lr = 1e-6
+nb_iter_max = 2000
+
+loss_iter = []
+
+# instant of the interpolation
+t = 0.5
+
+for i in range(nb_iter_max):
+ # Compute the Wasserstein 1D with torch backend
+ loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2)
+ # record the corresponding loss value
+ loss_iter.append(loss.clone().detach().cpu().numpy())
+ loss.backward()
+
+ # performs a step of projected gradient descent
+ with torch.no_grad():
+ grad = bary_torch.grad
+ bary_torch -= bary_torch.grad * lr # step
+ bary_torch.grad.zero_()
+ bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex
+
+pl.figure(3, figsize=(8, 4))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter')
+pl.legend()
+pl.title('Wasserstein barycenter computed by gradient descent')
+pl.show()
+
+pl.figure(4)
+pl.plot(range(nb_iter_max), loss_iter, lw=3)
+pl.title('Evolution of the loss along iterations', fontsize=16)
+pl.show()
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 27df656..7d73907 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -63,7 +63,7 @@ res = np.empty((n_seed, 25))
# %% Compute statistics
for seed in range(n_seed):
for i, n_projections in enumerate(n_projections_arr):
- res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)
+ res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed)
res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)
diff --git a/ot/__init__.py b/ot/__init__.py
index 5bd4bab..f20332c 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -42,7 +42,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
-from .sliced import sliced_wasserstein_distance
+from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -51,8 +51,9 @@ __version__ = "0.8.0dev"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d',
+ 'emd2_1d', 'wasserstein_1d', 'backend',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
+ 'max_sliced_wasserstein_distance',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
diff --git a/ot/backend.py b/ot/backend.py
index 358297c..d3df44c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -103,6 +103,8 @@ class Backend():
__name__ = None
__type__ = None
+ rng_ = None
+
def __str__(self):
return self.__name__
@@ -540,6 +542,36 @@ class Backend():
"""
raise NotImplementedError()
+ def seed(self, seed=None):
+ r"""
+ Sets the seed for the random generator.
+
+ This function follows the api from :any:`numpy.random.seed`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html
+ """
+ raise NotImplementedError()
+
+ def rand(self, *size, type_as=None):
+ r"""
+ Generate uniform random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
+ def randn(self, *size, type_as=None):
+ r"""
+ Generate normal Gaussian random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
r"""
Creates a sparse tensor in COOrdinate format.
@@ -632,6 +664,8 @@ class NumpyBackend(Backend):
__name__ = 'numpy'
__type__ = np.ndarray
+ rng_ = np.random.RandomState()
+
def to_numpy(self, a):
return a
@@ -793,6 +827,16 @@ class NumpyBackend(Backend):
def reshape(self, a, shape):
return np.reshape(a, shape)
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ return self.rng_.rand(*size)
+
+ def randn(self, *size, type_as=None):
+ return self.rng_.randn(*size)
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return coo_matrix((data, (rows, cols)), shape=shape)
@@ -845,6 +889,11 @@ class JaxBackend(Backend):
__name__ = 'jax'
__type__ = jax_type
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = jax.random.PRNGKey(42)
+
def to_numpy(self, a):
return np.array(a)
@@ -1010,6 +1059,24 @@ class JaxBackend(Backend):
def reshape(self, a, shape):
return jnp.reshape(a, shape)
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_ = jax.random.PRNGKey(seed)
+
+ def rand(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.uniform(subkey, shape=size)
+
+ def randn(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.normal(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.normal(subkey, shape=size)
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
# Currently, JAX does not support sparse matrices
data = self.to_numpy(data)
@@ -1064,8 +1131,13 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+ rng_ = None
+
def __init__(self):
+ self.rng_ = torch.Generator()
+ self.rng_.seed()
+
from torch.autograd import Function
# define a function that takes inputs val and grads
@@ -1102,12 +1174,16 @@ class TorchBackend(Backend):
return res
def zeros(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
if type_as is None:
return torch.zeros(shape)
else:
return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
def ones(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
if type_as is None:
return torch.ones(shape)
else:
@@ -1120,6 +1196,8 @@ class TorchBackend(Backend):
return torch.arange(start, stop, step, device=type_as.device)
def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
if type_as is None:
return torch.full(shape, fill_value)
else:
@@ -1293,6 +1371,26 @@ class TorchBackend(Backend):
def reshape(self, a, shape):
return torch.reshape(a, shape)
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_.manual_seed(seed)
+ elif isinstance(seed, torch.Generator):
+ self.rng_ = seed
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ else:
+ return torch.rand(size=size, generator=self.rng_)
+
+ def randn(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ else:
+ return torch.randn(size=size, generator=self.rng_)
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4e95ccf..2c18a88 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -13,20 +13,23 @@ import multiprocessing
import sys
import numpy as np
-from scipy.sparse import coo_matrix
import warnings
from . import cvx
from .cvx import barycenter
+
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
+from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
-__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
+
def check_number_threads(numThreads):
"""Checks whether or not the requested number of threads has a valid value.
@@ -115,10 +118,10 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
.. warning::
This function is necessary because the C++ solver in emd_c
- discards all samples in the distributions with
- zeros weights. This means that while the primal variable (transport
+ discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
- on the samples with weights different from zero.
+ on the samples with weights different from zero.
First we compute the constraints violations:
@@ -215,26 +218,26 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
format
.. note:: This function is backend-compatible and will work on arrays
- from all compatible backends.
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) array-like, float
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) array-like, float
- Target histogram (uniform weight if empty list)
- M : (ns,nt) array-like, float
- Loss matrix (c-order array in numpy with type float64)
- numItermax : int, optional (default=100000)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
+ numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
- log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual variables.
- Otherwise returns only the optimal transportation matrix.
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
- If True, centers the dual potential using function
+ If True, centers the dual potential using function
:ref:`center_ot_dual`.
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
@@ -242,9 +245,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
Returns
-------
- gamma: array-like, shape (ns, nt)
+ gamma: array-like, shape (ns, nt)
Optimal transportation matrix for the given
- parameters
+ parameters
log: dict, optional
If input log is true, a dictionary containing the
cost and dual variables and exit status
@@ -277,10 +280,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
regularized OT"""
# convert to numpy if list
- a, b, M = list_to_array(a, b, M)
+ a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
- nx = get_backend(M0, a0, b0)
+ nx = get_backend(M0, a0, b0)
# convert to numpy
M = nx.to_numpy(M)
@@ -302,9 +305,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
"Dimension mismatch, check dimensions of M with a and b"
# ensure that same mass
- np.testing.assert_almost_equal(a.sum(0),
- b.sum(0), err_msg='a and b vector must have the same sum')
- b=b*a.sum()/b.sum()
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0), err_msg='a and b vector must have the same sum')
+ b = b * a.sum() / b.sum()
asel = a != 0
bsel = b != 0
@@ -415,10 +418,10 @@ def emd2(a, b, M, processes=1,
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT"""
- a, b, M = list_to_array(a, b, M)
+ a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
- nx = get_backend(M0, a0, b0)
+ nx = get_backend(M0, a0, b0)
# convert to numpy
M = nx.to_numpy(M)
@@ -427,7 +430,7 @@ def emd2(a, b, M, processes=1,
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64, order= 'C')
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -463,8 +466,8 @@ def emd2(a, b, M, processes=1,
log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0,b0, M0), (log['u'], log['v'], G))
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
@@ -479,8 +482,8 @@ def emd2(a, b, M, processes=1,
G = nx.from_numpy(G, type_as=M0)
cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0,b0, M0), (nx.from_numpy(u, type_as=a0),
- nx.from_numpy(v, type_as=b0),G))
+ (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0), G))
check_result(result_code)
return cost
@@ -603,305 +606,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
-
-
-def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the OT matrix
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the cost.
- Otherwise returns only the optimal transportation matrix.
-
- Returns
- -------
- gamma: (ns, nt) ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is True, a dictionary containing the cost
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd_1d(x_a, x_b, a, b)
- array([[0. , 0.5],
- [0.5, 0. ]])
- >>> ot.emd_1d(x_a, x_b)
- array([[0. , 0.5],
- [0.5, 0. ]])
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd : EMD for multidimensional distributions
- ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
- transportation matrix)
- """
- a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
- nx = get_backend(x_a, x_b)
-
- assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
- assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
-
- # if empty array given then use uniform distributions
- if a is None or a.ndim == 0 or len(a) == 0:
- a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
- if b is None or b.ndim == 0 or len(b) == 0:
- b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
-
- # ensure that same mass
- np.testing.assert_almost_equal(
- nx.sum(a, axis=0),
- nx.sum(b, axis=0),
- err_msg='a and b vector must have the same sum'
- )
- b = b * nx.sum(a) / nx.sum(b)
-
- x_a_1d = nx.reshape(x_a, (-1,))
- x_b_1d = nx.reshape(x_b, (-1,))
- perm_a = nx.argsort(x_a_1d)
- perm_b = nx.argsort(x_b_1d)
-
- G_sorted, indices, cost = emd_1d_sorted(
- nx.to_numpy(a[perm_a]),
- nx.to_numpy(b[perm_b]),
- nx.to_numpy(x_a_1d[perm_a]),
- nx.to_numpy(x_b_1d[perm_b]),
- metric=metric, p=p
- )
-
- G = nx.coo_matrix(
- G_sorted,
- perm_a[indices[:, 0]],
- perm_b[indices[:, 1]],
- shape=(a.shape[0], b.shape[0]),
- type_as=x_a
- )
- if dense:
- G = nx.todense(G)
- elif str(nx) == "jax":
- warnings.warn("JAX does not support sparse matrices, converting to dense")
- if log:
- log = {'cost': cost}
- return G, log
- return G
-
-
-def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the loss
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Only used if log is set to True. Due to implementation details,
- this function runs faster when dense is set to False.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the transportation matrix.
- Otherwise returns only the loss.
-
- Returns
- -------
- loss: float
- Cost associated to the optimal transportation
- log: dict
- If input log is True, a dictionary containing the Optimal transportation
- matrix for the given parameters
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd2_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd2_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.emd2_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd2 : EMD for multidimensional distributions
- ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
- instead of the cost)
- """
- # If we do not return G (log==False), then we should not to cast it to dense
- # (useless overhead)
- G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
- dense=dense and log, log=True)
- cost = log_emd['cost']
- if log:
- log_emd = {'G': G}
- return cost, log_emd
- return cost
-
-
-def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
- r"""Solves the p-Wasserstein distance problem between 1d measures and returns
- the distance
-
- .. math::
- \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p}
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
-
- where :
-
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- p: float, optional (default=1.0)
- The order of the p-Wasserstein distance to be computed
-
- Returns
- -------
- dist: float
- p-Wasserstein distance
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function wasserstein_1d accepts
- lists and performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.wasserstein_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.wasserstein_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd_1d : EMD for 1d distributions
- """
- cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
- dense=False, log=False)
- return np.power(cost_emd, 1. / p)
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
new file mode 100644
index 0000000..42554aa
--- /dev/null
+++ b/ot/lp/solver_1d.py
@@ -0,0 +1,367 @@
+# -*- coding: utf-8 -*-
+"""
+Exact solvers for the 1D Wasserstein distance using cvxopt
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+
+from .emd_wrap import emd_1d_sorted
+from ..backend import get_backend
+from ..utils import list_to_array
+
+
+def quantile_function(qs, cws, xs):
+ r""" Computes the quantile function of an empirical distribution
+
+ Parameters
+ ----------
+ qs: array-like, shape (n,)
+ Quantiles at which the quantile function is evaluated
+ cws: array-like, shape (m, ...)
+ cumulative weights of the 1D empirical distribution, if batched, must be similar to xs
+ xs: array-like, shape (n, ...)
+ locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions
+
+ Returns
+ -------
+ q: array-like, shape (..., n)
+ The quantiles of the distribution
+ """
+ nx = get_backend(qs, cws)
+ n = xs.shape[0]
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warninng related to non-contiguous arrays
+ cws = cws.T.contiguous()
+ qs = qs.T.contiguous()
+ else:
+ cws = cws.T
+ qs = qs.T
+ idx = nx.searchsorted(cws, qs).T
+ return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
+
+
+def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
+ r"""
+ Computes the 1 dimensional OT loss [15] between two (batched) empirical
+ distributions
+
+ .. math:
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+
+ It is formally the p-Wasserstein distance raised to the power p.
+ We do so in a vectorized way by first building the individual quantile functions then integrating them.
+
+ This function should be preferred to `emd_1d` whenever the backend is
+ different to numpy, and when gradients over
+ either sample positions or weights are required.
+
+ Parameters
+ ----------
+ u_values: array-like, shape (n, ...)
+ locations of the first empirical distribution
+ v_values: array-like, shape (m, ...)
+ locations of the second empirical distribution
+ u_weights: array-like, shape (n, ...), optional
+ weights of the first empirical distribution, if None then uniform weights are used
+ v_weights: array-like, shape (m, ...), optional
+ weights of the second empirical distribution, if None then uniform weights are used
+ p: int, optional
+ order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1
+ require_sort: bool, optional
+ sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to
+ the function, default is True
+
+ Returns
+ -------
+ cost: float/array-like, shape (...)
+ the batched EMD
+
+ References
+ ----------
+ .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
+
+ """
+
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cumweights = nx.cumsum(u_weights, 0)
+ v_cumweights = nx.cumsum(v_weights, 0)
+
+ qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
+ u_quantiles = quantile_function(qs, u_cumweights, u_values)
+ v_quantiles = quantile_function(qs, v_cumweights, v_values)
+ qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
+ delta = qs[1:, ...] - qs[:-1, ...]
+ diff_quantiles = nx.abs(u_quantiles - v_quantiles)
+
+ if p == 1:
+ return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
+
+
+def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(x_a, x_b, a, b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+ >>> ot.emd_1d(x_a, x_b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+ """
+ a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
+ nx = get_backend(x_a, x_b)
+
+ assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+ assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+
+ # if empty array given then use uniform distributions
+ if a is None or a.ndim == 0 or len(a) == 0:
+ a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
+ if b is None or b.ndim == 0 or len(b) == 0:
+ b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
+
+ # ensure that same mass
+ np.testing.assert_almost_equal(
+ nx.sum(a, axis=0),
+ nx.sum(b, axis=0),
+ err_msg='a and b vector must have the same sum'
+ )
+ b = b * nx.sum(a) / nx.sum(b)
+
+ x_a_1d = nx.reshape(x_a, (-1,))
+ x_b_1d = nx.reshape(x_b, (-1,))
+ perm_a = nx.argsort(x_a_1d)
+ perm_b = nx.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(
+ nx.to_numpy(a[perm_a]),
+ nx.to_numpy(b[perm_b]),
+ nx.to_numpy(x_a_1d[perm_a]),
+ nx.to_numpy(x_b_1d[perm_b]),
+ metric=metric, p=p
+ )
+
+ G = nx.coo_matrix(
+ G_sorted,
+ perm_a[indices[:, 0]],
+ perm_b[indices[:, 1]],
+ shape=(a.shape[0], b.shape[0]),
+ type_as=x_a
+ )
+ if dense:
+ G = nx.todense(G)
+ elif str(nx) == "jax":
+ warnings.warn("JAX does not support sparse matrices, converting to dense")
+ if log:
+ log = {'cost': cost}
+ return G, log
+ return G
+
+
+def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the loss
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+ x_b : (nt,) or (ns, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Only used if log is set to True. Due to implementation details,
+ this function runs faster when dense is set to False.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the transportation matrix.
+ Otherwise returns only the loss.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict
+ If input log is True, a dictionary containing the Optimal transportation
+ matrix for the given parameters
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd2_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd2_1d(x_a, x_b, a, b)
+ 0.5
+ >>> ot.emd2_1d(x_a, x_b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd2 : EMD for multidimensional distributions
+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
+ instead of the cost)
+ """
+ # If we do not return G (log==False), then we should not to cast it to dense
+ # (useless overhead)
+ G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
+ dense=dense and log, log=True)
+ cost = log_emd['cost']
+ if log:
+ log_emd = {'G': G}
+ return cost, log_emd
+ return cost
diff --git a/ot/sliced.py b/ot/sliced.py
index 4792576..d3dc3f2 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -1,61 +1,73 @@
"""
-Sliced Wasserstein Distance.
+Sliced OT Distances
"""
# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
import numpy as np
+from .backend import get_backend, NumpyBackend
+from .utils import list_to_array
-def get_random_projections(n_projections, d, seed=None):
+def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
r"""
Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})`
Parameters
----------
- n_projections : int
- number of samples requested
d : int
dimension of the space
+ n_projections : int
+ number of samples requested
seed: int or RandomState, optional
Seed used for numpy random number generator
+ backend:
+ Backend to ue for random generation
Returns
-------
- out: ndarray, shape (n_projections, d)
+ out: ndarray, shape (d, n_projections)
The uniform unit vectors on the sphere
Examples
--------
>>> n_projections = 100
>>> d = 5
- >>> projs = get_random_projections(n_projections, d)
- >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE
+ >>> projs = get_random_projections(d, n_projections)
+ >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE
True
"""
- if not isinstance(seed, np.random.RandomState):
- random_state = np.random.RandomState(seed)
+ if backend is None:
+ nx = NumpyBackend()
+ else:
+ nx = backend
+
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ projections = seed.randn(d, n_projections)
else:
- random_state = seed
+ if seed is not None:
+ nx.seed(seed)
+ projections = nx.randn(d, n_projections, type_as=type_as)
- projections = random_state.normal(0., 1., [n_projections, d])
- norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True)
- projections = projections / norm
+ projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True))
return projections
-def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
+def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
+ projections=None, seed=None, log=False):
r"""
- Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance
+ Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
.. math::
- \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}}
+ \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{p}}
where :
@@ -74,8 +86,12 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed
samples weights in the target domain
n_projections : int, optional
Number of projections used for the Monte-Carlo approximation
+ p: float, optional =
+ Power p used for computing the sliced Wasserstein
+ projections: shape (dim, n_projections), optional
+ Projection matrix (n_projections and seed are not used in this case)
seed: int or RandomState or None, optional
- Seed used for numpy random number generator
+ Seed used for random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
@@ -100,10 +116,18 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed
.. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
"""
- from .lp import emd2_1d
+ from .lp import wasserstein_1d
- X_s = np.asanyarray(X_s)
- X_t = np.asanyarray(X_t)
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ if a is not None and b is not None and projections is None:
+ nx = get_backend(X_s, X_t, a, b)
+ elif a is not None and b is not None and projections is not None:
+ nx = get_backend(X_s, X_t, a, b, projections)
+ elif a is None and b is None and projections is not None:
+ nx = get_backend(X_s, X_t, projections)
+ else:
+ nx = get_backend(X_s, X_t)
n = X_s.shape[0]
m = X_t.shape[0]
@@ -114,31 +138,120 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed
X_t.shape[1]))
if a is None:
- a = np.full(n, 1 / n)
+ a = nx.full(n, 1 / n)
if b is None:
- b = np.full(m, 1 / m)
+ b = nx.full(m, 1 / m)
d = X_s.shape[1]
- projections = get_random_projections(n_projections, d, seed)
+ if projections is None:
+ projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+
+ X_s_projections = nx.dot(X_s, projections)
+ X_t_projections = nx.dot(X_t, projections)
- X_s_projections = np.dot(projections, X_s.T)
- X_t_projections = np.dot(projections, X_t.T)
+ projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)
+ res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p)
if log:
- projected_emd = np.empty(n_projections)
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
+
+
+def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
+ projections=None, seed=None, log=False):
+ r"""
+ Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
+
+ .. math::
+ \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta _in
+ \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\#
+ \mu, \theta_\# \nu)]^{\frac{1}{p}}
+
+ where :
+
+ - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional =
+ Power p used for computing the sliced Wasserstein
+ projections: shape (dim, n_projections), optional
+ Projection matrix (n_projections and seed are not used in this case)
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Sliced Wasserstein Cost
+ log : dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 20
+ >>> reg = 0.1
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+
+ .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
+ """
+ from .lp import wasserstein_1d
+
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ if a is not None and b is not None and projections is None:
+ nx = get_backend(X_s, X_t, a, b)
+ elif a is not None and b is not None and projections is not None:
+ nx = get_backend(X_s, X_t, a, b, projections)
+ elif a is None and b is None and projections is not None:
+ nx = get_backend(X_s, X_t, projections)
else:
- projected_emd = None
+ nx = get_backend(X_s, X_t)
+
+ n = X_s.shape[0]
+ m = X_t.shape[0]
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+
+ if a is None:
+ a = nx.full(n, 1 / n)
+ if b is None:
+ b = nx.full(m, 1 / m)
+
+ d = X_s.shape[1]
+
+ if projections is None:
+ projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
- res = 0.
+ X_s_projections = nx.dot(X_s, projections)
+ X_t_projections = nx.dot(X_t, projections)
- for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)):
- emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False)
- if projected_emd is not None:
- projected_emd[i] = emd
- res += emd
+ projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)
- res = (res / n_projections) ** 0.5
+ res = nx.max(projected_emd) ** (1.0 / p)
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
new file mode 100644
index 0000000..2c470c2
--- /dev/null
+++ b/test/test_1d_solver.py
@@ -0,0 +1,85 @@
+"""Tests for module 1d Wasserstein solver"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.lp import wasserstein_1d
+
+from ot.backend import get_backend_list
+from scipy.stats import wasserstein_distance
+
+backend_list = get_backend_list()
+
+
+def test_emd_1d_emd2_1d_with_weights():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d(nx):
+ from scipy.stats import wasserstein_distance
+
+ rng = np.random.RandomState(0)
+
+ n = 100
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ xb = nx.from_numpy(x)
+ rho_ub = nx.from_numpy(rho_u)
+ rho_vb = nx.from_numpy(rho_v)
+
+ # test 1 : wasserstein_1d should be close to scipy W_1 implementation
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
+ wasserstein_distance(x, x, rho_u, rho_v))
+
+ # test 2 : wasserstein_1d should be close to one when only translating the support
+ np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2),
+ 1.)
+
+ # test 3 : arrays test
+ X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1)
+ Xb = nx.from_numpy(X)
+ res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
+ np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
diff --git a/test/test_backend.py b/test/test_backend.py
index 0f11ace..1832b91 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -208,6 +208,11 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.reshape(M, (5, 3, 2))
with pytest.raises(NotImplementedError):
+ nx.seed(42)
+ with pytest.raises(NotImplementedError):
+ nx.rand()
+ with pytest.raises(NotImplementedError):
+ nx.randn()
nx.coo_matrix(M, M, M)
with pytest.raises(NotImplementedError):
nx.issparse(M)
@@ -248,6 +253,7 @@ def test_func_backends(nx):
Mb = nx.from_numpy(M)
vb = nx.from_numpy(v)
+
val = nx.from_numpy(val)
sp_rowb = nx.from_numpy(sp_row)
@@ -255,6 +261,7 @@ def test_func_backends(nx):
sp_datab = nx.from_numpy(sp_data)
A = nx.set_gradients(val, v, v)
+
lst_b.append(nx.to_numpy(A))
lst_name.append('set_gradients')
@@ -505,6 +512,35 @@ def test_func_backends(nx):
assert np.allclose(a1, a2, atol=1e-7)
+def test_random_backends(nx):
+
+ tmp_u = nx.rand()
+
+ assert tmp_u < 1
+
+ tmp_n = nx.randn()
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.rand(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n))
+
+ assert np.all(M1 >= 0)
+ assert np.all(M1 < 1)
+ assert M1.shape == (5, 2)
+ assert np.allclose(M1, M2)
+
+ nx.seed(0)
+ M1 = nx.to_numpy(nx.randn(5, 2))
+ nx.seed(0)
+ M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u))
+
+ nx.seed(42)
+ v1 = nx.randn()
+ v2 = nx.randn()
+ assert v1 != v2
+
+
def test_gradients_backends():
rnd = np.random.RandomState(0)
diff --git a/test/test_ot.py b/test/test_ot.py
index 4dfc510..5bfde1d 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -8,11 +8,11 @@ import warnings
import numpy as np
import pytest
-from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch
+from scipy.stats import wasserstein_distance
def test_emd_dimension_and_mass_mismatch():
@@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d():
ot.emd_1d(u, v, [], [])
-def test_emd_1d_emd2_1d_with_weights():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- w_u = rng.uniform(0., 1., n)
- w_u = w_u / w_u.sum()
-
- w_v = rng.uniform(0., 1., m)
- w_v = w_v / w_v.sum()
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd(w_u, w_v, M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
-
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
-
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
- np.testing.assert_allclose(wass_sp, wass1d_euc)
-
- # check constraints
- np.testing.assert_allclose(w_u, G.sum(1))
- np.testing.assert_allclose(w_v, G.sum(0))
-
-
-def test_wass_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
-
- wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
-
- # check loss is similar
- np.testing.assert_allclose(np.sqrt(wass), wass1d)
-
-
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
diff --git a/test/test_sliced.py b/test/test_sliced.py
index a07d975..0bd74ec 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -1,6 +1,7 @@
"""Tests for module sliced"""
# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -14,7 +15,7 @@ from ot.sliced import get_random_projections
def test_get_random_projections():
rng = np.random.RandomState(0)
projections = get_random_projections(1000, 50, rng)
- np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.)
+ np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.)
def test_sliced_same_dist():
@@ -48,12 +49,12 @@ def test_sliced_log():
y = rng.randn(n, 4)
u = ot.utils.unif(n)
- res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True)
assert len(log) == 2
projections = log["projections"]
projected_emds = log["projected_emds"]
- assert len(projections) == len(projected_emds) == 10
+ assert projections.shape[1] == len(projected_emds) == 10
for emd in projected_emds:
assert emd > 0
@@ -83,3 +84,86 @@ def test_1d_sliced_equals_emd():
res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
np.testing.assert_almost_equal(res ** 2, expected)
+
+
+def test_max_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_max_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ assert res > 0.
+
+
+def test_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
+
+
+def test_max_sliced_backend(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(2 * n, 2)
+
+ P = rng.randn(2, 20)
+ P = P / np.sqrt((P**2).sum(0, keepdims=True))
+
+ n_projections = 20
+
+ xb = nx.from_numpy(x)
+ yb = nx.from_numpy(y)
+ Pb = nx.from_numpy(P)
+
+ val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
+
+ val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+ val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0)
+
+ assert val > 0
+ assert val == val2
+
+ valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb))
+
+ assert np.allclose(val0, valb)
diff --git a/test/test_utils.py b/test/test_utils.py
index 0650ce2..40f4e49 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -109,7 +109,7 @@ def test_dist():
D2 = ot.dist(x, x)
D3 = ot.dist(x)
- D4 = ot.dist(x, x, metric='minkowski', p=0.5)
+ D4 = ot.dist(x, x, metric='minkowski', p=2)
assert D4[0, 1] == D4[1, 0]