diff options
-rw-r--r-- | README.md | 11 | ||||
-rw-r--r-- | docs/source/readme.rst | 51 | ||||
-rw-r--r-- | examples/backends/plot_sliced_wass_grad_flow_pytorch.py | 185 | ||||
-rw-r--r-- | examples/backends/plot_wass1d_torch.py | 152 | ||||
-rw-r--r-- | examples/sliced-wasserstein/plot_variance.py | 2 | ||||
-rw-r--r-- | ot/__init__.py | 5 | ||||
-rw-r--r-- | ot/backend.py | 98 | ||||
-rw-r--r-- | ot/lp/__init__.py | 367 | ||||
-rw-r--r-- | ot/lp/solver_1d.py | 367 | ||||
-rw-r--r-- | ot/sliced.py | 181 | ||||
-rw-r--r-- | test/test_1d_solver.py | 85 | ||||
-rw-r--r-- | test/test_backend.py | 36 | ||||
-rw-r--r-- | test/test_ot.py | 57 | ||||
-rw-r--r-- | test/test_sliced.py | 90 | ||||
-rw-r--r-- | test/test_utils.py | 2 |
15 files changed, 1244 insertions, 445 deletions
@@ -33,7 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). -* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -285,4 +285,11 @@ You can also post bug reports and feature requests in Github issues. Make sure t [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 -[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on +Machine Learning (pp. 4104-4113). PMLR. diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 82d3e6c..ee32e2b 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -24,7 +24,7 @@ POT provides the following generic OT solvers (links to examples): for regularized OT [7]. - Entropic regularization OT solver with `Sinkhorn Knopp Algorithm <auto_examples/plot_OT_1D.html>`__ - [2] , stabilized version [9] [10], greedy Sinkhorn [22] and + [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and `Screening Sinkhorn [26] <auto_examples/plot_screenkhorn_1D.html>`__. - Bregman projections for `Wasserstein @@ -54,6 +54,9 @@ POT provides the following generic OT solvers (links to examples): solver <auto_examples/plot_stochastic.html>`__ for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +- `Stochastic solver of Gromov + Wasserstein <auto_examples/gromov/plot_gromov.html>`__ + for large-scale problem with any loss functions [33] - Non regularized `free support Wasserstein barycenters <auto_examples/barycenters/plot_free_support_barycenter.html>`__ [20]. @@ -137,19 +140,12 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) (build only, not necessary when installing wheels - from pip or conda) +- Cython (>=0.23) (build only, not necessary when installing from pip + or conda) Pip installation ^^^^^^^^^^^^^^^^ -Note that due to a limitation of pip, ``cython`` and ``numpy`` need to -be installed prior to installing POT. This can be done easily with - -.. code:: console - - pip install numpy cython - You can install the toolbox through PyPI with: .. code:: console @@ -183,7 +179,8 @@ without errors: import ot -Note that for easier access the module is name ot instead of pot. +Note that for easier access the module is named ``ot`` instead of +``pot``. Dependencies ~~~~~~~~~~~~ @@ -222,7 +219,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix Wd = ot.emd2(a, b, M) # exact linear program Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT @@ -232,7 +229,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix T = ot.emd(a, b, M) # exact linear program T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT @@ -287,6 +284,10 @@ The contributors to this library are - `Ievgen Redko <https://ievred.github.io/>`__ (Laplacian DA, JCPOT) - `Adrien Corenflos <https://adriencorenflos.github.io/>`__ (Sliced Wasserstein Distance) +- `Tanguy Kerdoncuff <https://hv0nnus.github.io/>`__ (Sampled Gromov + Wasserstein) +- `Minhui Huang <https://mhhuang95.github.io>`__ (Projection Robust + Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various @@ -476,6 +477,30 @@ of measures <https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf>`__, Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate +Descent Method for Computing the Projection Robust Wasserstein +Distance <http://proceedings.mlr.press/v139/huang21e.html>`__, +Proceedings of the 38th International Conference on Machine Learning +(ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov +Wasserstein <https://hal.archives-ouvertes.fr/hal-03232509/document>`__, +Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., +& Peyré, G. (2019, April). `Interpolating between optimal transport and +MMD using Sinkhorn +divergences <http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf>`__. +In The 22nd International Conference on Artificial Intelligence and +Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., +Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein +distance and its use for +gans <https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf>`__. +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern +Recognition (pp. 10648-10656). + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT .. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py new file mode 100644 index 0000000..05b9952 --- /dev/null +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -0,0 +1,185 @@ +r""" +================================= +Sliced Wasserstein barycenter and gradient flow with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the sliced Wasserstein +loss between two empirical distributions [31]. + +In the first example one we perform a +gradient flow on the support of a distribution that minimize the sliced +Wassersein distance as poposed in [36]. + +In the second exemple we optimize with a gradient descent the sliced +Wasserstein barycenter between two distributions as in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of +measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions. In International Conference on +Machine Learning (pp. 4104-4113). PMLR. + + +""" +# Author: Rémi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + + +# %% +# Loading the data + + +import numpy as np +import matplotlib.pylab as pl +import torch +import ot +import matplotlib.animation as animation + +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] + +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) + +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 + +pl.figure(1, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) + +# %% +# Sliced Wasserstein gradient flow with Pytorch +# --------------------------------------------- + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True) +x2_torch = torch.tensor(x2).to(device=device) + + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +for i in range(nb_iter_max): + + loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = x1_torch.grad + x1_torch -= grad * lr / (1 + i / 5e1) # step + x1_torch.grad.zero_() + x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy() + +xb = x1_torch.clone().detach().cpu().numpy() + +pl.figure(2, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$') +pl.title('Sliced Wasserstein gradient flow') +pl.legend() +ax = pl.axis() + +# %% +# Animate trajectories of the gradient flow along iteration +# ------------------------------------------------------- + +pl.figure(3, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) + +# %% +# Compute the Sliced Wasserstein Barycenter +# +x1_torch = torch.tensor(x1).to(device=device) +x3_torch = torch.tensor(x3).to(device=device) +xbinit = np.random.randn(500, 2) * 10 + 16 +xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True) + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +alpha = 0.5 + +for i in range(nb_iter_max): + + loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \ + + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = xbary_torch.grad + xbary_torch -= grad * lr # / (1 + i / 5e1) # step + xbary_torch.grad.zero_() + x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy() + +xb = xbary_torch.clone().detach().cpu().numpy() + +pl.figure(4, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter') +pl.title('Sliced Wasserstein barycenter') +pl.legend() +ax = pl.axis() + + +# %% +# Animate trajectories of the barycenter along gradient descent +# ------------------------------------------------------- + +pl.figure(5, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py new file mode 100644 index 0000000..0abdd6d --- /dev/null +++ b/examples/backends/plot_wass1d_torch.py @@ -0,0 +1,152 @@ +r""" +================================= +Wasserstein 1D with PyTorch +================================= + +In this small example, we consider the following minization problem: + +.. math:: + \mu^* = \min_\mu W(\mu,\nu) + +where :math:`\nu` is a reference 1D measure. The problem is handled +by a projected gradient descent method, where the gradient is computed +by pyTorch automatic differentiation. The projection on the simplex +ensures that the iterate will remain on the probability simplex. + +This example illustrates both `wasserstein_1d` function and backend use within +the POT framework. +""" +# Author: Nicolas Courty <ncourty@irisa.fr> +# Rémi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import matplotlib as mpl +import torch + +from ot.lp import wasserstein_1d +from ot.datasets import make_1D_gauss as gauss +from ot.utils import proj_simplex + +red = np.array(mpl.colors.to_rgb('red')) +blue = np.array(mpl.colors.to_rgb('blue')) + + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# enforce sum to one on the support +a = a / a.sum() +b = b / b.sum() + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device).requires_grad_(True) +b_torch = torch.tensor(b).to(device=device) + +lr = 1e-6 +nb_iter_max = 800 + +loss_iter = [] + +pl.figure(1, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a_torch.grad + a_torch -= a_torch.grad * lr # step + a_torch.grad.zero_() + a_torch.data = proj_simplex(a_torch) # projection onto the simplex + + # plot one curve every 10 iterations + if i % 10 == 0: + mix = float(i) / nb_iter_max + pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red) + +pl.legend() +pl.title('Distribution along the iterations of the projected gradient descent') +pl.show() + +pl.figure(2) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() + +# %% +# Wasserstein barycenter +# --------- +# In this example, we consider the following Wasserstein barycenter problem +# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$ +# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t` +# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient +# descent method, where the gradient is computed by pyTorch automatic differentiation. +# The projection on the simplex ensures that the iterate will remain on the +# probability simplex. +# +# This example illustrates both `wasserstein_1d` function and backend use within the +# POT framework. + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device) +b_torch = torch.tensor(b).to(device=device) +bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True) + + +lr = 1e-6 +nb_iter_max = 2000 + +loss_iter = [] + +# instant of the interpolation +t = 0.5 + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = bary_torch.grad + bary_torch -= bary_torch.grad * lr # step + bary_torch.grad.zero_() + bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex + +pl.figure(3, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter') +pl.legend() +pl.title('Wasserstein barycenter computed by gradient descent') +pl.show() + +pl.figure(4) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 27df656..7d73907 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -63,7 +63,7 @@ res = np.empty((n_seed, 25)) # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) diff --git a/ot/__init__.py b/ot/__init__.py index 5bd4bab..f20332c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -42,7 +42,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance +from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -51,8 +51,9 @@ __version__ = "0.8.0dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', + 'emd2_1d', 'wasserstein_1d', 'backend', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/backend.py b/ot/backend.py index 358297c..d3df44c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -103,6 +103,8 @@ class Backend(): __name__ = None __type__ = None + rng_ = None + def __str__(self): return self.__name__ @@ -540,6 +542,36 @@ class Backend(): """ raise NotImplementedError() + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): r""" Creates a sparse tensor in COOrdinate format. @@ -632,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + rng_ = np.random.RandomState() + def to_numpy(self, a): return a @@ -793,6 +827,16 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + return self.rng_.rand(*size) + + def randn(self, *size, type_as=None): + return self.rng_.randn(*size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return coo_matrix((data, (rows, cols)), shape=shape) @@ -845,6 +889,11 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + rng_ = None + + def __init__(self): + self.rng_ = jax.random.PRNGKey(42) + def to_numpy(self, a): return np.array(a) @@ -1010,6 +1059,24 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_ = jax.random.PRNGKey(seed) + + def rand(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.uniform(subkey, shape=size) + + def randn(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.normal(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.normal(subkey, shape=size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): # Currently, JAX does not support sparse matrices data = self.to_numpy(data) @@ -1064,8 +1131,13 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + rng_ = None + def __init__(self): + self.rng_ = torch.Generator() + self.rng_.seed() + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1102,12 +1174,16 @@ class TorchBackend(Backend): return res def zeros(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.zeros(shape) else: return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) def ones(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.ones(shape) else: @@ -1120,6 +1196,8 @@ class TorchBackend(Backend): return torch.arange(start, stop, step, device=type_as.device) def full(self, shape, fill_value, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.full(shape, fill_value) else: @@ -1293,6 +1371,26 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_.manual_seed(seed) + elif isinstance(seed, torch.Generator): + self.rng_ = seed + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is not None: + return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + else: + return torch.rand(size=size, generator=self.rng_) + + def randn(self, *size, type_as=None): + if type_as is not None: + return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + else: + return torch.randn(size=size, generator=self.rng_) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4e95ccf..2c18a88 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -13,20 +13,23 @@ import multiprocessing import sys import numpy as np -from scipy.sparse import coo_matrix import warnings from . import cvx from .cvx import barycenter + # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d + from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend -__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + def check_number_threads(numThreads): """Checks whether or not the requested number of threads has a valid value. @@ -115,10 +118,10 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): .. warning:: This function is necessary because the C++ solver in emd_c - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. + on the samples with weights different from zero. First we compute the constraints violations: @@ -215,26 +218,26 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) array-like, float + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. @@ -242,9 +245,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): Returns ------- - gamma: array-like, shape (ns, nt) + gamma: array-like, shape (ns, nt) Optimal transportation matrix for the given - parameters + parameters log: dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status @@ -277,10 +280,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): regularized OT""" # convert to numpy if list - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -302,9 +305,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass - np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b = b * a.sum() / b.sum() asel = a != 0 bsel = b != 0 @@ -415,10 +418,10 @@ def emd2(a, b, M, processes=1, ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -427,7 +430,7 @@ def emd2(a, b, M, processes=1, a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order= 'C') + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -463,8 +466,8 @@ def emd2(a, b, M, processes=1, log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (log['u'], log['v'], G)) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -479,8 +482,8 @@ def emd2(a, b, M, processes=1, G = nx.from_numpy(G, type_as=M0) cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0),G)) + (a0, b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0), G)) check_result(result_code) return cost @@ -603,305 +606,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X - - -def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the OT matrix - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - log: boolean, optional (default=False) - If True, returns a dictionary containing the cost. - Otherwise returns only the optimal transportation matrix. - - Returns - ------- - gamma: (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is True, a dictionary containing the cost - - - Examples - -------- - - Simple example with obvious solution. The function emd_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd_1d(x_a, x_b, a, b) - array([[0. , 0.5], - [0.5, 0. ]]) - >>> ot.emd_1d(x_a, x_b) - array([[0. , 0.5], - [0.5, 0. ]]) - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd : EMD for multidimensional distributions - ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the - transportation matrix) - """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) - nx = get_backend(x_a, x_b) - - assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - - # if empty array given then use uniform distributions - if a is None or a.ndim == 0 or len(a) == 0: - a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] - if b is None or b.ndim == 0 or len(b) == 0: - b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] - - # ensure that same mass - np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), - err_msg='a and b vector must have the same sum' - ) - b = b * nx.sum(a) / nx.sum(b) - - x_a_1d = nx.reshape(x_a, (-1,)) - x_b_1d = nx.reshape(x_b, (-1,)) - perm_a = nx.argsort(x_a_1d) - perm_b = nx.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), - metric=metric, p=p - ) - - G = nx.coo_matrix( - G_sorted, - perm_a[indices[:, 0]], - perm_b[indices[:, 1]], - shape=(a.shape[0], b.shape[0]), - type_as=x_a - ) - if dense: - G = nx.todense(G) - elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to dense") - if log: - log = {'cost': cost} - return G, log - return G - - -def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the loss - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Only used if log is set to True. Due to implementation details, - this function runs faster when dense is set to False. - log: boolean, optional (default=False) - If True, returns a dictionary containing the transportation matrix. - Otherwise returns only the loss. - - Returns - ------- - loss: float - Cost associated to the optimal transportation - log: dict - If input log is True, a dictionary containing the Optimal transportation - matrix for the given parameters - - - Examples - -------- - - Simple example with obvious solution. The function emd2_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd2_1d(x_a, x_b, a, b) - 0.5 - >>> ot.emd2_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd2 : EMD for multidimensional distributions - ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix - instead of the cost) - """ - # If we do not return G (log==False), then we should not to cast it to dense - # (useless overhead) - G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, - dense=dense and log, log=True) - cost = log_emd['cost'] - if log: - log_emd = {'G': G} - return cost, log_emd - return cost - - -def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): - r"""Solves the p-Wasserstein distance problem between 1d measures and returns - the distance - - .. math:: - \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p} - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - - where : - - - x_a and x_b are the samples - - a and b are the sample weights - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - p: float, optional (default=1.0) - The order of the p-Wasserstein distance to be computed - - Returns - ------- - dist: float - p-Wasserstein distance - - - Examples - -------- - - Simple example with obvious solution. The function wasserstein_1d accepts - lists and performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.wasserstein_1d(x_a, x_b, a, b) - 0.5 - >>> ot.wasserstein_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd_1d : EMD for 1d distributions - """ - cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=False, log=False) - return np.power(cost_emd, 1. / p) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py new file mode 100644 index 0000000..42554aa --- /dev/null +++ b/ot/lp/solver_1d.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +""" +Exact solvers for the 1D Wasserstein distance using cvxopt +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Author: Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + +import numpy as np +import warnings + +from .emd_wrap import emd_1d_sorted +from ..backend import get_backend +from ..utils import list_to_array + + +def quantile_function(qs, cws, xs): + r""" Computes the quantile function of an empirical distribution + + Parameters + ---------- + qs: array-like, shape (n,) + Quantiles at which the quantile function is evaluated + cws: array-like, shape (m, ...) + cumulative weights of the 1D empirical distribution, if batched, must be similar to xs + xs: array-like, shape (n, ...) + locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions + + Returns + ------- + q: array-like, shape (..., n) + The quantiles of the distribution + """ + nx = get_backend(qs, cws) + n = xs.shape[0] + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + cws = cws.T.contiguous() + qs = qs.T.contiguous() + else: + cws = cws.T + qs = qs.T + idx = nx.searchsorted(cws, qs).T + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + + +def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True): + r""" + Computes the 1 dimensional OT loss [15] between two (batched) empirical + distributions + + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + + It is formally the p-Wasserstein distance raised to the power p. + We do so in a vectorized way by first building the individual quantile functions then integrating them. + + This function should be preferred to `emd_1d` whenever the backend is + different to numpy, and when gradients over + either sample positions or weights are required. + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + + Returns + ------- + cost: float/array-like, shape (...) + the batched EMD + + References + ---------- + .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport. + + """ + + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cumweights = nx.cumsum(u_weights, 0) + v_cumweights = nx.cumsum(v_weights, 0) + + qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) + u_quantiles = quantile_function(qs, u_cumweights, u_values) + v_quantiles = quantile_function(qs, v_cumweights, v_values) + qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)]) + delta = qs[1:, ...] - qs[:-1, ...] + diff_quantiles = nx.abs(u_quantiles - v_quantiles) + + if p == 1: + return nx.sum(delta * nx.abs(diff_quantiles), axis=0) + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + + +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) + array([[0. , 0.5], + [0.5, 0. ]]) + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + + # if empty array given then use uniform distributions + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] + + # ensure that same mass + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) + if dense: + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + if log: + log = {'cost': cost} + return G, log + return G + + +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the loss + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) + """ + # If we do not return G (log==False), then we should not to cast it to dense + # (useless overhead) + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost diff --git a/ot/sliced.py b/ot/sliced.py index 4792576..d3dc3f2 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,61 +1,73 @@ """ -Sliced Wasserstein Distance. +Sliced OT Distances """ # Author: Adrien Corenflos <adrien.corenflos@aalto.fi> +# Nicolas Courty <ncourty@irisa.fr> +# Rémi Flamary <remi.flamary@polytechnique.edu> # # License: MIT License import numpy as np +from .backend import get_backend, NumpyBackend +from .utils import list_to_array -def get_random_projections(n_projections, d, seed=None): +def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` Parameters ---------- - n_projections : int - number of samples requested d : int dimension of the space + n_projections : int + number of samples requested seed: int or RandomState, optional Seed used for numpy random number generator + backend: + Backend to ue for random generation Returns ------- - out: ndarray, shape (n_projections, d) + out: ndarray, shape (d, n_projections) The uniform unit vectors on the sphere Examples -------- >>> n_projections = 100 >>> d = 5 - >>> projs = get_random_projections(n_projections, d) - >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + >>> projs = get_random_projections(d, n_projections) + >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE True """ - if not isinstance(seed, np.random.RandomState): - random_state = np.random.RandomState(seed) + if backend is None: + nx = NumpyBackend() + else: + nx = backend + + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + projections = seed.randn(d, n_projections) else: - random_state = seed + if seed is not None: + nx.seed(seed) + projections = nx.randn(d, n_projections, type_as=type_as) - projections = random_state.normal(0., 1., [n_projections, d]) - norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) - projections = projections / norm + projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True)) return projections -def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): r""" - Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance .. math:: - \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{p}} where : @@ -74,8 +86,12 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed samples weights in the target domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional - Seed used for numpy random number generator + Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. @@ -100,10 +116,18 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 """ - from .lp import emd2_1d + from .lp import wasserstein_1d - X_s = np.asanyarray(X_s) - X_t = np.asanyarray(X_t) + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) n = X_s.shape[0] m = X_t.shape[0] @@ -114,31 +138,120 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed X_t.shape[1])) if a is None: - a = np.full(n, 1 / n) + a = nx.full(n, 1 / n) if b is None: - b = np.full(m, 1 / m) + b = nx.full(m, 1 / m) d = X_s.shape[1] - projections = get_random_projections(n_projections, d, seed) + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - X_s_projections = np.dot(projections, X_s.T) - X_t_projections = np.dot(projections, X_t.T) + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) + res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p) if log: - projected_emd = np.empty(n_projections) + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance + + .. math:: + \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta _in + \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\# + \mu, \theta_\# \nu)]^{\frac{1}{p}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + """ + from .lp import wasserstein_1d + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) else: - projected_emd = None + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = nx.full(n, 1 / n) + if b is None: + b = nx.full(m, 1 / m) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) - res = 0. + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): - emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) - if projected_emd is not None: - projected_emd[i] = emd - res += emd + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) - res = (res / n_projections) ** 0.5 + res = nx.max(projected_emd) ** (1.0 / p) if log: return res, {"projections": projections, "projected_emds": projected_emd} return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py new file mode 100644 index 0000000..2c470c2 --- /dev/null +++ b/test/test_1d_solver.py @@ -0,0 +1,85 @@ +"""Tests for module 1d Wasserstein solver""" + +# Author: Adrien Corenflos <adrien.corenflos@aalto.fi> +# Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.lp import wasserstein_1d + +from ot.backend import get_backend_list +from scipy.stats import wasserstein_distance + +backend_list = get_backend_list() + + +def test_emd_1d_emd2_1d_with_weights(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(w_u, G.sum(1)) + np.testing.assert_allclose(w_v, G.sum(0)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d(nx): + from scipy.stats import wasserstein_distance + + rng = np.random.RandomState(0) + + n = 100 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + xb = nx.from_numpy(x) + rho_ub = nx.from_numpy(rho_u) + rho_vb = nx.from_numpy(rho_v) + + # test 1 : wasserstein_1d should be close to scipy W_1 implementation + np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), + wasserstein_distance(x, x, rho_u, rho_v)) + + # test 2 : wasserstein_1d should be close to one when only translating the support + np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2), + 1.) + + # test 3 : arrays test + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) diff --git a/test/test_backend.py b/test/test_backend.py index 0f11ace..1832b91 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -208,6 +208,11 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) with pytest.raises(NotImplementedError): + nx.seed(42) + with pytest.raises(NotImplementedError): + nx.rand() + with pytest.raises(NotImplementedError): + nx.randn() nx.coo_matrix(M, M, M) with pytest.raises(NotImplementedError): nx.issparse(M) @@ -248,6 +253,7 @@ def test_func_backends(nx): Mb = nx.from_numpy(M) vb = nx.from_numpy(v) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -255,6 +261,7 @@ def test_func_backends(nx): sp_datab = nx.from_numpy(sp_data) A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -505,6 +512,35 @@ def test_func_backends(nx): assert np.allclose(a1, a2, atol=1e-7) +def test_random_backends(nx): + + tmp_u = nx.rand() + + assert tmp_u < 1 + + tmp_n = nx.randn() + + nx.seed(0) + M1 = nx.to_numpy(nx.rand(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n)) + + assert np.all(M1 >= 0) + assert np.all(M1 < 1) + assert M1.shape == (5, 2) + assert np.allclose(M1, M2) + + nx.seed(0) + M1 = nx.to_numpy(nx.randn(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u)) + + nx.seed(42) + v1 = nx.randn() + v2 = nx.randn() + assert v1 != v2 + + def test_gradients_backends(): rnd = np.random.RandomState(0) diff --git a/test/test_ot.py b/test/test_ot.py index 4dfc510..5bfde1d 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,11 +8,11 @@ import warnings import numpy as np import pytest -from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch +from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d(): ot.emd_1d(u, v, [], []) -def test_emd_1d_emd2_1d_with_weights(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - w_u = rng.uniform(0., 1., n) - w_u = w_u / w_u.sum() - - w_v = rng.uniform(0., 1., m) - w_v = w_v / w_v.sum() - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd(w_u, w_v, M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(w_u, G.sum(1)) - np.testing.assert_allclose(w_v, G.sum(0)) - - -def test_wass_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - - wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) - - # check loss is similar - np.testing.assert_allclose(np.sqrt(wass), wass1d) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 diff --git a/test/test_sliced.py b/test/test_sliced.py index a07d975..0bd74ec 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -1,6 +1,7 @@ """Tests for module sliced""" # Author: Adrien Corenflos <adrien.corenflos@aalto.fi> +# Nicolas Courty <ncourty@irisa.fr> # # License: MIT License @@ -14,7 +15,7 @@ from ot.sliced import get_random_projections def test_get_random_projections(): rng = np.random.RandomState(0) projections = get_random_projections(1000, 50, rng) - np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.) def test_sliced_same_dist(): @@ -48,12 +49,12 @@ def test_sliced_log(): y = rng.randn(n, 4) u = ot.utils.unif(n) - res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True) assert len(log) == 2 projections = log["projections"] projected_emds = log["projected_emds"] - assert len(projections) == len(projected_emds) == 10 + assert projections.shape[1] == len(projected_emds) == 10 for emd in projected_emds: assert emd > 0 @@ -83,3 +84,86 @@ def test_1d_sliced_equals_emd(): res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) np.testing.assert_almost_equal(res ** 2, expected) + + +def test_max_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_max_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert res > 0. + + +def test_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.sliced_wasserstein_distance(x, y, projections=P) + + val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) + + +def test_max_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) + + val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 0650ce2..40f4e49 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -109,7 +109,7 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) - D4 = ot.dist(x, x, metric='minkowski', p=0.5) + D4 = ot.dist(x, x, metric='minkowski', p=2) assert D4[0, 1] == D4[1, 0] |