authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-02-23 08:31:01 +0100
committerGitHub <noreply@github.com>2023-02-23 08:31:01 +0100
commit80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch)
treee4c2e938896243842e290d8fcf78879a8f6960bf
parent97feeb32b6c069d7bb44cd995531c2b820d59771 (diff)
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
-rw-r--r--CONTRIBUTORS.md1
-rw-r--r--README.md10
-rw-r--r--RELEASES.md4
-rw-r--r--examples/backends/plot_ssw_unif_torch.py153
-rw-r--r--examples/plot_compute_wasserstein_circle.py161
-rw-r--r--examples/sliced-wasserstein/plot_variance_ssw.py111
-rw-r--r--ot/__init__.py13
-rw-r--r--ot/backend.py204
-rw-r--r--ot/lp/__init__.py7
-rw-r--r--ot/lp/solver_1d.py627
-rw-r--r--ot/sliced.py185
-rw-r--r--ot/utils.py30
-rw-r--r--test/test_1d_solver.py127
-rw-r--r--test/test_backend.py46
-rw-r--r--test/test_sliced.py186
-rw-r--r--test/test_utils.py10
16 files changed, 1852 insertions, 23 deletions
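For reference, a minimal usage sketch (not part of the patch) of the circle solvers this commit exposes at the top level of `ot`. The sample sizes and values below are arbitrary; the only assumptions are that the package is built with this patch and that coordinates are given on [0,1[, as the docstrings in ot/lp/solver_1d.py require.

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
u = rng.rand(20)   # samples on the circle, coordinates in [0, 1[
v = rng.rand(30)

w1 = ot.wasserstein_circle(u, v, p=1)              # level-median solver for p=1
w2 = ot.binary_search_circle(u, v, p=2)            # binary search for p >= 1
w2u = ot.semidiscrete_wasserstein2_unif_circle(u)  # closed form w.r.t. Unif(S^1)
print(w1, w2, w2u)
```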
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 67d8337..1437821 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -41,6 +41,7 @@ The contributors to this library are:
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
+* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
## Acknowledgments
diff --git a/README.md b/README.md
index 7c9475b..d5e6854 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,8 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
+* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
@@ -292,4 +294,10 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.
-[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file
+[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+
+[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file
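To complement the two new README entries above, here is a small sketch (not part of the patch, with arbitrary sample sizes and number of projections) of the spherical sliced-Wasserstein entry points; the solvers check that the rows have unit norm, hence the normalization.

```python
import numpy as np
import ot

rng = np.random.RandomState(42)
xs = rng.randn(100, 3)
xt = rng.randn(100, 3)
xs /= np.linalg.norm(xs, axis=-1, keepdims=True)  # project the samples onto S^2
xt /= np.linalg.norm(xt, axis=-1, keepdims=True)

# SSW_2 between the two empirical distributions, and w.r.t. the uniform distribution
ssw = ot.sliced_wasserstein_sphere(xs, xt, n_projections=200, seed=0, p=2)
ssw_unif = ot.sliced_wasserstein_sphere_unif(xs, n_projections=200, seed=0)
```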
diff --git a/RELEASES.md b/RELEASES.md
index 4ed3625..f8ef653 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -4,6 +4,10 @@
#### New features
+- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434)
+- Added the Wasserstein distance on the circle in `ot.lp.solver_1d.wasserstein_circle` (PR #434)
+- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434)
+- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434)
- Added Bures Wasserstein distance in `ot.gaussian` (PR ##428)
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py
new file mode 100644
index 0000000..d1de5a9
--- /dev/null
+++ b/examples/backends/plot_ssw_unif_torch.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+r"""
+================================================
+Spherical Sliced-Wasserstein Embedding on Sphere
+================================================
+
+Here, we aim at transforming samples into a uniform
+distribution on the sphere by minimizing SSW:
+
+.. math::
+ \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})
+
+where :math:`\nu=\mathrm{Unif}(S^1)`.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+import numpy as np
+import matplotlib.pyplot as pl
+import matplotlib.animation as animation
+import torch
+import torch.nn.functional as F
+
+import ot
+
+
+# %%
+# Data generation
+# ---------------
+
+torch.manual_seed(1)
+
+N = 1000
+x0 = torch.rand(N, 3)
+x0 = F.normalize(x0, dim=-1)
+
+
+# %%
+# Plot data
+# ---------
+
+def plot_sphere(ax):
+ xlist = np.linspace(-1.0, 1.0, 50)
+ ylist = np.linspace(-1.0, 1.0, 50)
+ r = np.linspace(1.0, 1.0, 50)
+ X, Y = np.meshgrid(xlist, ylist)
+
+ Z = np.sqrt(r**2 - X**2 - Y**2)
+
+ ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3)
+ ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half
+
+
+# plot the distributions
+pl.figure(1)
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5)
+ax.set_title('Data distribution')
+ax.legend()
+
+
+# %%
+# Gradient descent
+# ----------------
+
+x = x0.clone()
+x.requires_grad_(True)
+
+n_iter = 500
+lr = 100
+
+losses = []
+xvisu = torch.zeros(n_iter, N, 3)
+
+for i in range(n_iter):
+ sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
+ grad_x = torch.autograd.grad(sw, x)[0]
+
+ x = x - lr * grad_x
+ x = F.normalize(x, p=2, dim=1)
+
+ losses.append(sw.item())
+ xvisu[i, :, :] = x.detach().clone()
+
+ if i % 100 == 0:
+ print("Iter: {:3d}, loss={}".format(i, losses[-1]))
+
+pl.figure(1)
+pl.semilogy(losses)
+pl.grid()
+pl.title('SSW')
+pl.xlabel("Iterations")
+
+
+# %%
+# Plot trajectories of generated samples along iterations
+# -------------------------------------------------------
+
+ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
+
+fig = pl.figure(3, (10, 10))
+for i in range(9):
+ # pl.subplot(3, 3, i + 1)
+ # ax = pl.axes(projection='3d')
+ ax = fig.add_subplot(3, 3, i + 1, projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5)
+ ax.set_title('Iter. {}'.format(ivisu[i]))
+ #ax.axis("off")
+ if i == 0:
+ ax.legend()
+
+
+# %%
+# Animate trajectories of generated samples along iterations
+# -----------------------------------------------------------
+
+pl.figure(4, (8, 8))
+
+
+def _update_plot(i):
+ i = 3 * i
+ pl.clf()
+ ax = pl.axes(projection='3d')
+ plot_sphere(ax)
+ ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples', alpha=0.5)
+ ax.axis("off")
+ ax.set_xlim((-1.5, 1.5))
+ ax.set_ylim((-1.5, 1.5))
+ ax.set_title('Iter. {}'.format(i))
+ return 1
+
+
+print(xvisu.shape)
+
+i = 0
+ax = pl.axes(projection='3d')
+plot_sphere(ax)
+ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples', alpha=0.5)
+ax.axis("off")
+ax.set_xlim((-1.5, 1.5))
+ax.set_ylim((-1.5, 1.5))
+ax.set_title('Iter. {}'.format(ivisu[i]))
+
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
+# %%
diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py
new file mode 100644
index 0000000..3ede96f
--- /dev/null
+++ b/examples/plot_compute_wasserstein_circle.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+"""
+=========================
+OT distance on the Circle
+=========================
+
+Shows how to compute the Wasserstein distance on the circle
+
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+from scipy.special import iv
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+
+def pdf_von_Mises(theta, mu, kappa):
+ pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa))
+ return pdf
+
+
+t = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
+
+mu1 = 1
+kappa1 = 20
+
+mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10)
+
+
+pdf1 = pdf_von_Mises(t, mu1, kappa1)
+
+
+pl.figure(1)
+for k, mu in enumerate(mu_targets):
+ pdf_t = pdf_von_Mises(t, mu, kappa1)
+ if k == 0:
+ label = "Source distributions"
+ else:
+ label = None
+ pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label)
+
+pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution")
+pl.legend()
+
+mu2 = 0
+kappa2 = kappa1
+
+x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi
+x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi
+
+angles = np.linspace(0, 2 * np.pi, 150)
+
+pl.figure(2)
+pl.plot(np.cos(angles), np.sin(angles), c="k")
+pl.xlim(-1.25, 1.25)
+pl.ylim(-1.25, 1.25)
+pl.scatter(np.cos(x1), np.sin(x1), c="b")
+pl.scatter(np.cos(x2), np.sin(x2), c="r")
+
+#########################################################################################
+# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle
+# ---------------------------------------------------------------------------------------
+# This example illustrates the periodicity of the Wasserstein distance on the circle.
+# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}`
+# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution
+# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`.
+# The Wasserstein distance on the circle takes into account the periodicity
+# and attains its maximum at the antipodal point of :math:`\mu_{\mathrm{target}}`, contrary to the
+# Euclidean version.
+
+#%% Compute and plot distributions
+
+mu_targets = np.linspace(0, 2 * np.pi, 200)
+xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi
+
+n_try = 5
+
+xts = np.zeros((n_try, 200, 500))
+for i in range(n_try):
+ for k, mu in enumerate(mu_targets):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi
+ xts[i, k] = xt
+
+# Put data on S^1=[0,1[
+xts2 = xts / (2 * np.pi)
+xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi)
+
+L_w2_circle = np.zeros((n_try, 200))
+L_w2 = np.zeros((n_try, 200))
+
+for i in range(n_try):
+ w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
+ w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)
+
+ L_w2_circle[i] = w2_circle
+ L_w2[i] = w2
+
+m_w2_circle = np.mean(L_w2_circle, axis=0)
+std_w2_circle = np.std(L_w2_circle, axis=0)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5)
+pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein")
+pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5)
+pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$")
+pl.legend()
+pl.xlabel(r"$\mu_{\mathrm{source}}$")
+pl.show()
+
+
+########################################################################
+# Wasserstein distance between von Mises and uniform for different kappa
+# ----------------------------------------------------------------------
+# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`.
+
+#%% Compute Wasserstein between Von Mises and uniform
+
+kappas = np.logspace(-5, 2, 100)
+n_try = 20
+
+xts = np.zeros((n_try, 100, 500))
+for i in range(n_try):
+ for k, kappa in enumerate(kappas):
+ # np.random.vonmises deals with data on [-pi, pi[
+ xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi
+ xts[i, k] = xt / (2 * np.pi)
+
+L_w2 = np.zeros((n_try, 100))
+for i in range(n_try):
+ L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
+
+m_w2 = np.mean(L_w2, axis=0)
+std_w2 = np.std(L_w2, axis=0)
+
+pl.figure(1)
+pl.plot(kappas, m_w2)
+pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
+pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
+pl.xlabel(r"$\kappa$")
+pl.show()
+
+# %%
diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py
new file mode 100644
index 0000000..83d458f
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance_ssw.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Spherical Sliced Wasserstein on distributions in S^2
+====================================================
+
+This example illustrates the computation of the spherical sliced Wasserstein discrepancy as
+proposed in [46].
+
+[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). "Spherical Sliced-Wasserstein". International Conference on Learning Representations.
+
+"""
+
+# Author: Clément Bonet <clement.bonet@univ-ubs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+xs = np.random.randn(n, 3)
+xt = np.random.randn(n, 3)
+
+xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True))
+xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True))
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+fig = pl.figure(figsize=(10, 10))
+ax = pl.axes(projection='3d')
+ax.grid(False)
+
+u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j]
+x = np.cos(u) * np.sin(v)
+y = np.sin(u) * np.sin(v)
+z = np.cos(v)
+ax.plot_surface(x, y, z, color="gray", alpha=0.03)
+ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray")
+
+ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source")
+ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target")
+
+fs = 10
+# Labels
+ax.set_xlabel('x', fontsize=fs)
+ax.set_ylabel('y', fontsize=fs)
+ax.set_zlabel('z', fontsize=fs)
+
+ax.view_init(20, 120)
+ax.set_xlim(-1.5, 1.5)
+ax.set_ylim(-1.5, 1.5)
+ax.set_zlim(-1.5, 1.5)
+
+# Ticks
+ax.set_xticks([-1, 0, 1])
+ax.set_yticks([-1, 0, 1])
+ax.set_zticks([-1, 0, 1])
+
+pl.legend(loc=0)
+pl.title("Source and Target distribution")
+
+###############################################################################
+# Spherical Sliced Wasserstein for different seeds and number of projections
+# --------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###############################################################################
+# Plot Spherical Sliced Wasserstein
+# ---------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Spherical Sliced Wasserstein Distance with 95% confidence interval')
+
+pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index 0b55e0c..45d5cfa 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -38,12 +38,15 @@ from . import solvers
from . import gaussian
# OT functions
-from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
+from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
-from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
+ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
@@ -60,8 +63,10 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport', 'solve',
- 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers']
+ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
+ 'binary_search_circle', 'wasserstein_circle',
+ 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
diff --git a/ot/backend.py b/ot/backend.py
index 337e040..0779243 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -534,9 +534,9 @@ class Backend():
"""
raise NotImplementedError()
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
r"""
- Pads a tensor.
+ Pads a tensor with a given value (0 by default).
This function follows the api from :any:`numpy.pad`
@@ -895,6 +895,62 @@ class Backend():
"""
raise NotImplementedError()
+ def tile(self, a, reps):
+ r"""
+ Construct an array by repeating `a` the number of times given by `reps`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
+ """
+ raise NotImplementedError()
+
+ def floor(self, a):
+ r"""
+ Return the floor of the input element-wise
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html
+ """
+ raise NotImplementedError()
+
+ def prod(self, a, axis=None):
+ r"""
+ Return the product of all elements.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html
+ """
+ raise NotImplementedError()
+
+ def sort2(self, a, axis=None):
+ r"""
+ Return the sorted array and the indices to sort the array
+
+ See: https://pytorch.org/docs/stable/generated/torch.sort.html
+ """
+ raise NotImplementedError()
+
+ def qr(self, a):
+ r"""
+ Return the QR factorization
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html
+ """
+ raise NotImplementedError()
+
+ def atan2(self, a, b):
+ r"""
+ Element-wise arctangent
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html
+ """
+ raise NotImplementedError()
+
+ def transpose(self, a, axes=None):
+ r"""
+ Returns a tensor with its axes permuted according to `axes` (reversed if `axes` is None).
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1039,8 +1095,8 @@ class NumpyBackend(Backend):
def concatenate(self, arrays, axis=0):
return np.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return np.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return np.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
@@ -1185,6 +1241,44 @@ class NumpyBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return np.tile(a, reps)
+
+ def floor(self, a):
+ return np.floor(a)
+
+ def prod(self, a, axis=0):
+ return np.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ np_version = tuple([int(k) for k in np.__version__.split(".")])
+ if np_version < (1, 22, 0):
+ M, N = a.shape[-2], a.shape[-1]
+ K = min(M, N)
+
+ if len(a.shape) >= 3:
+ n = a.shape[0]
+
+ qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N))
+
+ for i in range(a.shape[0]):
+ qs[i], rs[i] = np.linalg.qr(a[i])
+
+ else:
+ return np.linalg.qr(a)
+
+ return qs, rs
+ return np.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return np.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return np.transpose(a, axes)
+
class JaxBackend(Backend):
"""
@@ -1351,8 +1445,8 @@ class JaxBackend(Backend):
def concatenate(self, arrays, axis=0):
return jnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return jnp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return jnp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
@@ -1511,6 +1605,27 @@ class JaxBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return jnp.tile(a, reps)
+
+ def floor(self, a):
+ return jnp.floor(a)
+
+ def prod(self, a, axis=0):
+ return jnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return jnp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return jnp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return jnp.transpose(a, axes)
+
class TorchBackend(Backend):
"""
@@ -1729,13 +1844,13 @@ class TorchBackend(Backend):
def concatenate(self, arrays, axis=0):
return torch.cat(arrays, dim=axis)
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
from torch.nn.functional import pad
# pad_width is an array of ndim tuples indicating how many 0 before and after
# we need to add. We first need to make it compliant with torch syntax, that
# starts with the last dim, then second last, etc.
how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
- return pad(a, how_pad)
+ return pad(a, how_pad, value=value)
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
@@ -1934,6 +2049,29 @@ class TorchBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating_point
+ def tile(self, a, reps):
+ return a.repeat(reps)
+
+ def floor(self, a):
+ return torch.floor(a)
+
+ def prod(self, a, axis=0):
+ return torch.prod(a, dim=axis)
+
+ def sort2(self, a, axis=-1):
+ return torch.sort(a, axis)
+
+ def qr(self, a):
+ return torch.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return torch.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ if axes is None:
+ axes = tuple(range(a.ndim)[::-1])
+ return a.permute(axes)
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2096,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover
def concatenate(self, arrays, axis=0):
return cp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return cp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return cp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
@@ -2284,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return cp.tile(a, reps)
+
+ def floor(self, a):
+ return cp.floor(a)
+
+ def prod(self, a, axis=0):
+ return cp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return cp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return cp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return cp.transpose(a, axes)
+
class TensorflowBackend(Backend):
@@ -2454,8 +2613,8 @@ class TensorflowBackend(Backend):
def concatenate(self, arrays, axis=0):
return tnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return tnp.pad(a, pad_width, mode="constant")
+ def zero_pad(self, a, pad_width, value=0):
+ return tnp.pad(a, pad_width, mode="constant", constant_values=value)
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
@@ -2646,3 +2805,24 @@ class TensorflowBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating
+
+ def tile(self, a, reps):
+ return tnp.tile(a, reps)
+
+ def floor(self, a):
+ return tf.floor(a)
+
+ def prod(self, a, axis=0):
+ return tnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return tf.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return tf.math.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return tf.transpose(a, perm=axes)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 17411d0..7d0640f 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -20,14 +20,17 @@ from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter']
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
+ 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
def check_number_threads(numThreads):
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 43763a9..e7add89 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
distributions
.. math:
- OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions then integrating them.
@@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log_emd = {'G': G}
return cost, log_emd
return cost
+
+
+def roll_cols(M, shifts):
+ r"""
+ Utils functions which allow to shift the order of each row of a 2d matrix
+
+ Parameters
+ ----------
+ M : (nr, nc) ndarray
+ Matrix to shift
+ shifts: int or (nr,) ndarray
+
+ Returns
+ -------
+ Shifted array
+
+ Examples
+ --------
+ >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
+ >>> roll_cols(M, 2)
+ array([[2, 3, 1],
+ [5, 6, 4],
+ [8, 9, 7]])
+ >>> roll_cols(M, np.array([[1],[2],[1]]))
+ array([[3, 1, 2],
+ [5, 6, 4],
+ [9, 7, 8]])
+
+ References
+ ----------
+ https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
+ """
+ nx = get_backend(M)
+
+ n_rows, n_cols = M.shape
+
+ arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1))
+ arange2 = (arange1 - shifts) % n_cols
+
+ return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
+ r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ dCp: array-like, shape (n_batch, 1)
+ The batched right derivative
+ dCm: array-like, shape (n_batch, 1)
+ The batched left derivative
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ n = u_values.shape[-1]
+ m_batch, m = v_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warning related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ # quantiles of F_u evaluated in F_v^\theta
+ u_index = nx.searchsorted(u_cdf, v_cdf_theta)
+ u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)
+
+ # Deal with 1
+ u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
+ u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warning related to non-contiguous arrays
+ u_cdfm = u_cdfm.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
+ u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)
+
+ dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1)
+
+ dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1)
+
+ return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
+
+
+def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
+ r""" Computes the the cost (Equation (6.2) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ ot_cost: array-like, shape (n_batch,)
+ OT cost evaluated at theta
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ m_batch, m = v_values.shape
+ n_batch, n = u_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ # Put negative values at the end
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ # Compute the abscissa (quantile levels) at which the quantile functions are evaluated
+ cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
+ cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
+
+ delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warning related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+ cdf_axis = cdf_axis.contiguous()
+
+ # Compute icdf
+ u_index = nx.searchsorted(u_cdf, cdf_axis)
+ u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)
+
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+ v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
+ v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)
+
+ if p == 1:
+ ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
+ else:
+ ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)
+
+ return ot_cost
+
+
+def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True,
+ log=False):
+ r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ where:
+
+ - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on all backends but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC
+ Lp : int, optional
+ Upper bound dC
+ tm: float, optional
+ Lower bound theta
+ tp: float, optional
+ Upper bound theta
+ eps: float, optional
+ Stopping condition
+ require_sort: bool, optional
+ If True, sort the values.
+ log: bool, optional
+ If True, returns also the optimal theta
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict, optional
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> binary_search_circle(u.T, v.T, p=1)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cdf = nx.cumsum(u_weights, 0).T
+ v_cdf = nx.cumsum(v_weights, 0).T
+
+ u_values = u_values.T
+ v_values = v_values.T
+
+ L = max(Lm, Lp)
+
+ tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tm = nx.tile(tm, (1, m))
+ tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tp = nx.tile(tp, (1, m))
+ tc = (tm + tp) / 2
+
+ done = nx.zeros((u_values.shape[0], m))
+
+ cpt = 0
+ while nx.any(1 - done):
+ cpt += 1
+
+ dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+ done = ((dCp * dCm) <= 0) * 1
+
+ mask = ((tp - tm) < eps / L) * (1 - done)
+
+ if nx.any(mask):
+ # can probably be improved by computing only relevant values
+ dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p)
+ dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p)
+ Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+ Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+
+ mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
+ tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
+ done[nx.prod(mask, axis=-1) > 0] = 1
+ elif nx.any(1 - done):
+ tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
+ tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
+ tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2
+
+ w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+
+ if log:
+ return w, {"optimal_theta": tc[:, 0]}
+ return w
+
+
+def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True):
+ r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
+ using e.g. the atan2 function.
+ The function runs on all backends but tensorflow is not supported.
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein1_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ """
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)
+
+ cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0)
+ cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)
+
+ values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
+ delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
+ weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
+
+ sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
+ sum_weights[sum_weights < 0] = np.inf
+ inds = nx.argmin(sum_weights, axis=0)
+
+ levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)
+
+ return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
+
+
+def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True):
+ r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
+ the binary search algorithm proposed in [44] otherwise.
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates
+ using e.g. the atan2 function.
+
+ General loss returned:
+
+ .. math::
+ OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ For p=1, the cost is computed using the level median formulation of [45]:
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on all backends but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC. For p>1.
+ Lp : int, optional
+ Upper bound dC. For p>1.
+ tm: float, optional
+ Lower bound theta. For p>1.
+ tp: float, optional
+ Upper bound theta. For p>1.
+ eps: float, optional
+ Stopping condition. For p>1.
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if p == 1:
+ return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)
+
+ return binary_search_circle(u_values, v_values, u_weights, v_weights,
+ p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps,
+ require_sort=require_sort)
+
+
+def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
+ r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}
+
+ where:
+
+ - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ Parameters
+ ----------
+ u_values: ndarray, shape (n, ...)
+ Samples
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> x0 = np.array([[0], [0.2], [0.4]])
+ >>> semidiscrete_wasserstein2_unif_circle(x0)
+ array([0.02111111])
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical Sliced-Wasserstein. International Conference on Learning Representations.
+ """
+
+ if u_weights is not None:
+ nx = get_backend(u_values, u_weights)
+ else:
+ nx = get_backend(u_values)
+
+ n = u_values.shape[0]
+
+ u_values = u_values % 1
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+ u_values = nx.sort(u_values, 0)
+ u_cdf = nx.cumsum(u_weights, 0)
+ u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])
+
+ cpt1 = nx.sum(u_weights * u_values**2, axis=0)
+ u_mean = nx.sum(u_weights * u_values, axis=0)
+
+ ns = 1 - u_weights - 2 * u_cdf[:-1]
+ cpt2 = nx.sum(u_values * u_weights * ns, axis=0)
+
+ return cpt1 - u_mean**2 + cpt2 + 1 / 12
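As a quick sanity check of the closed form above (a sketch, not part of the patch): for a single Dirac the quadratic terms cancel and only the constant 1/12 remains, while a fine equispaced grid on [0,1[ approximates the uniform distribution and gives a value close to zero.

```python
import numpy as np
import ot

# One Dirac at 0.3: cpt1 = x^2 cancels u_mean^2 and the cross term is 0, so the result is 1/12
print(ot.semidiscrete_wasserstein2_unif_circle(np.array([0.3])))  # ~0.0833

# 1000 equispaced points on [0, 1[ are close to Unif(S^1): the squared distance is ~0
x = np.linspace(0, 1, 1000, endpoint=False)
print(ot.semidiscrete_wasserstein2_unif_circle(x))
```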
diff --git a/ot/sliced.py b/ot/sliced.py
index 20891a4..077ff0b 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -12,7 +12,8 @@ Sliced OT Distances
import numpy as np
from .backend import get_backend, NumpyBackend
-from .utils import list_to_array
+from .utils import list_to_array, get_coordinate_circle
+from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle
def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
@@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -208,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -258,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
+
+
+def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
+ p=2, seed=None, log=False):
+ r"""
+ Compute the spherical sliced-Wasserstein discrepancy.
+
+ .. math::
+ SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}}
+
+ where:
+
+ - :math:`P^U_\# \mu` stands for the pushforward of :math:`\mu` by the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`
+
+ The function runs on all backends but tensorflow is not supported.
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ X_t: ndarray, shape (n_samples_b, dim)
+ Samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional (default=2)
+ Power p used for computing the spherical sliced Wasserstein
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_sphere returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+ >>> n_samples_a = 20
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
+ >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical Sliced-Wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None and b is not None:
+ nx = get_backend(X_s, X_t, a, b)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n, d = X_s.shape
+ m, _ = X_t.shape
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+ if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("Xt is not on the sphere.")
+
+ # Uniform and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True))
+
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+ Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m))
+
+ projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p)
+ res = nx.mean(projected_emd) ** (1 / p)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
+
+
+def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False):
+ r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution.
+
+ .. math::
+ SSW_2(\mu_n, \nu)
+
+ where
+
+ - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}`
+ - :math:`\nu=\mathrm{Unif}(S^1)`
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_sphere_unif returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ ---------
+ >>> np.random.seed(42)
+ >>> x0 = np.random.randn(500,3)
+ >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
+ >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42)
+ >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3)
+ True
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical Sliced-Wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None:
+ nx = get_backend(X_s, a)
+ else:
+ nx = get_backend(X_s)
+
+ n, d = X_s.shape
+
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+
+ # Uniform and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+
+ projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
+ res = nx.mean(projected_emd) ** (1 / 2)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
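For context, a minimal usage sketch of the two solvers added above in ot/sliced.py (the point clouds, projection counts and seeds below are arbitrary illustrations, not values from this patch):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)

    # two point clouds normalized onto the unit sphere S^2
    x = rng.randn(200, 3)
    x = x / np.sqrt(np.sum(x**2, axis=-1, keepdims=True))
    y = rng.randn(300, 3)
    y = y / np.sqrt(np.sum(y**2, axis=-1, keepdims=True))

    # Spherical Sliced-Wasserstein between the two empirical measures
    ssw = ot.sliced_wasserstein_sphere(x, y, n_projections=100, seed=0)

    # SSW_2 between x and the uniform distribution on the sphere
    ssw_unif = ot.sliced_wasserstein_sphere_unif(x, n_projections=100, seed=0)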
diff --git a/ot/utils.py b/ot/utils.py
index 9093f09..3423a7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -375,6 +375,36 @@ def check_random_state(seed):
' instance'.format(seed))
+def get_coordinate_circle(x):
+ r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in
+ turn (in [0,1[).
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ Parameters
+ ----------
+ x: ndarray, shape (n, 2)
+ Samples on the circle with ambient coordinates
+
+ Returns
+ -------
+ x_t: ndarray, shape (n,)
+ Coordinates on [0,1[
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi)
+ >>> x1, y1 = np.cos(u), np.sin(u)
+ >>> x = np.concatenate([x1, y1]).T
+ >>> get_coordinate_circle(x)
+ array([0.2, 0.5, 0.8])
+ """
+ nx = get_backend(x)
+ x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ return x_t
+
+
class deprecated(object):
r"""Decorator to mark a function or class as deprecated.
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 20f307a..21abd1d 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -218,3 +218,130 @@ def test_emd1d_device_tf():
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
assert nx.dtype_device(emd)[1].startswith("GPU")
+
+
+def test_wasserstein_1d_circle():
+ # test that binary_search_circle and wasserstein_circle give results similar to emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+ wass1 = ot.emd2(w_u, w_v, M1)
+
+ wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+ w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+ M2 = M1**2
+ wass2 = ot.emd2(w_u, w_v, M2)
+ wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+ w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass1, wass1_bsc)
+ np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+ np.testing.assert_allclose(wass2, wass2_bsc)
+ np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+ w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+ w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+ nx.assert_same_dtype_device(xb, w1)
+ nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+ # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+ n = 20
+ m = 50000
+
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ # w_u = rng.uniform(0., 1., n)
+ # w_u = w_u / w_u.sum()
+
+ w_u = ot.utils.unif(n)
+ w_v = ot.utils.unif(m)
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+ wass2 = ot.emd2(w_u, w_v, M1**2)
+
+ wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+ wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+ w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+ nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+ optimal_thetas = log["optimal_theta"]
+
+ assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n, 2)
+ v = rng.rand(m, 1)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=2)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=1)
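For reference, a short sketch of how the circle solvers exercised in these tests can be called (the sample data is arbitrary; ot.wasserstein_circle, ot.binary_search_circle and ot.semidiscrete_wasserstein2_unif_circle are the entry points added by this PR):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    u = rng.rand(20)                   # supports are coordinates on [0, 1[ (the circle)
    v = rng.rand(30)
    w_u = ot.utils.unif(20)
    w_v = ot.utils.unif(30)

    # transport cost on the circle with ground cost d(x, y) = min(|x - y|, 1 - |x - y|)
    c1 = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
    c2 = ot.binary_search_circle(u, v, w_u, w_v, p=2)

    # W_2^2 between the discrete measure (u, w_u) and the uniform distribution on the circle
    c2_unif = ot.semidiscrete_wasserstein2_unif_circle(u, u_weights=w_u)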
diff --git a/test/test_backend.py b/test/test_backend.py
index 3628f61..fd9a761 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -282,6 +282,20 @@ def test_empty_backend():
nx.array_equal(M, M)
with pytest.raises(NotImplementedError):
nx.is_floating_point(M)
+ with pytest.raises(NotImplementedError):
+ nx.tile(M, (10, 1))
+ with pytest.raises(NotImplementedError):
+ nx.floor(M)
+ with pytest.raises(NotImplementedError):
+ nx.prod(M)
+ with pytest.raises(NotImplementedError):
+ nx.sort2(M)
+ with pytest.raises(NotImplementedError):
+ nx.qr(M)
+ with pytest.raises(NotImplementedError):
+ nx.atan2(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.transpose(M)
def test_func_backends(nx):
@@ -603,6 +617,38 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("isfinite")
+ A = nx.tile(vb, (10, 1))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("tile")
+
+ A = nx.floor(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("floor")
+
+ A = nx.prod(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("prod")
+
+ A, B = nx.sort2(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("sort2 sort")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("sort2 argsort")
+
+ A, B = nx.qr(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("QR Q")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("QR R")
+
+ A = nx.atan2(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("atan2")
+
+ A = nx.transpose(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("transpose")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
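As an aside, the backend primitives checked above (tile, floor, prod, sort2, qr, atan2, transpose) are the new methods this PR adds for the circle and sphere solvers; a minimal sketch of calling them through the NumPy backend (array contents are arbitrary):

    import numpy as np
    from ot.backend import get_backend

    M = np.random.rand(5, 3)
    v = np.random.rand(3)
    nx = get_backend(M, v)

    T = nx.tile(v, (10, 1))     # stack v into a (10, 3) array
    F = nx.floor(M)             # elementwise floor
    p = nx.prod(M)              # product reduction of M
    s, idx = nx.sort2(M)        # sorted values together with the argsort indices
    Q, R = nx.qr(M)             # QR factorization (used batched in ot.sliced)
    a = nx.atan2(v, v)          # elementwise arctan2
    Mt = nx.transpose(M)        # transpose (an axes permutation can also be given)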
diff --git a/test/test_sliced.py b/test/test_sliced.py
index eb13469..f54c799 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -266,3 +266,189 @@ def test_max_sliced_backend_device_tf():
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
+def test_projections_stiefel():
+ rng = np.random.RandomState(0)
+
+ n_projs = 500
+ x = rng.randn(100, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs,
+ seed=rng, log=True)
+
+ P = log["projections"]
+ P_T = np.transpose(P, [0, 2, 1])
+ np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]))
+
+
+def test_sliced_sphere_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_sphere_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_sphere_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+
+ y = rng.randn(m, 2)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
+ u = ot.utils.unif(m)
+
+ res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2)
+ expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2)
+
+ res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1)
+ expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1)
+
+ np.testing.assert_almost_equal(res ** 2, expected)
+ np.testing.assert_almost_equal(res1, expected1, decimal=3)
+
+
+@pytest.skip_backend("tf")
+def test_sliced_sphere_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(2 * n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, yb = nx.from_numpy(x, y, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere(xb, yb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_sliced_sphere_unif_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng)
+
+
+def test_sliced_sphere_unif_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_unif_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere_unif(xb)
+
+ nx.assert_same_dtype_device(xb, valb)
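The log output checked in test_sliced_sphere_log above is also available to user code; a short sketch (sample data arbitrary):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    x = rng.randn(100, 3)
    x = x / np.sqrt(np.sum(x**2, axis=-1, keepdims=True))
    y = rng.randn(100, 3)
    y = y / np.sqrt(np.sum(y**2, axis=-1, keepdims=True))

    ssw, log = ot.sliced_wasserstein_sphere(x, y, n_projections=10, seed=0, log=True)
    P = log["projections"]          # shape (10, d, 2), orthonormal 2-frames
    emds = log["projected_emds"]    # one circle EMD per projection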
diff --git a/test/test_utils.py b/test/test_utils.py
index 666c157..31b12ef 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -330,3 +330,13 @@ def test_OTResult():
for at in lst_attributes:
with pytest.raises(NotImplementedError):
getattr(res, at)
+
+
+def test_get_coordinate_circle():
+
+ u = np.random.rand(1, 100)
+ x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
+ x = np.concatenate([x1, y1]).T
+ x_p = ot.utils.get_coordinate_circle(x)
+
+ np.testing.assert_allclose(u[0], x_p)