summaryrefslogtreecommitdiff
path: root/ot/sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/sliced.py')
-rw-r--r--ot/sliced.py187
1 files changed, 184 insertions, 3 deletions
diff --git a/ot/sliced.py b/ot/sliced.py
index cf2d3be..077ff0b 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -12,7 +12,8 @@ Sliced OT Distances
import numpy as np
from .backend import get_backend, NumpyBackend
-from .utils import list_to_array
+from .utils import list_to_array, get_coordinate_circle
+from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle
def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
@@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
if projections is None:
projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+ else:
+ n_projections = projections.shape[1]
X_s_projections = nx.dot(X_s, projections)
X_t_projections = nx.dot(X_t, projections)
@@ -206,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -256,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
+
+
+def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
+ p=2, seed=None, log=False):
+ r"""
+ Compute the spherical sliced-Wasserstein discrepancy.
+
+ .. math::
+ SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}}
+
+ where:
+
+ - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ X_t: ndarray, shape (n_samples_b, dim)
+ Samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional (default=2)
+ Power p used for computing the spherical sliced Wasserstein
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_sphere returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+ >>> n_samples_a = 20
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
+ >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None and b is not None:
+ nx = get_backend(X_s, X_t, a, b)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n, d = X_s.shape
+ m, _ = X_t.shape
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+ if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("Xt is not on the sphere.")
+
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True))
+
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+ Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m))
+
+ projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p)
+ res = nx.mean(projected_emd) ** (1 / p)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
+
+
+def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False):
+ r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution.
+
+ .. math::
+ SSW_2(\mu_n, \nu)
+
+ where
+
+ - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}`
+ - :math:`\nu=\mathrm{Unif}(S^1)`
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ ---------
+ >>> np.random.seed(42)
+ >>> x0 = np.random.randn(500,3)
+ >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
+ >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42)
+ >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3)
+ True
+
+ References:
+ -----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None:
+ nx = get_backend(X_s, a)
+ else:
+ nx = get_backend(X_s)
+
+ n, d = X_s.shape
+
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+
+ projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
+ res = nx.mean(projected_emd) ** (1 / 2)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res