diff options
Diffstat (limited to 'ot/sliced.py')
-rw-r--r-- | ot/sliced.py | 187 |
1 files changed, 184 insertions, 3 deletions
diff --git a/ot/sliced.py b/ot/sliced.py index cf2d3be..077ff0b 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,8 @@ Sliced OT Distances import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array +from .utils import list_to_array, get_coordinate_circle +from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): @@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, if projections is None: projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + else: + n_projections = projections.shape[1] X_s_projections = nx.dot(X_s, projections) X_t_projections = nx.dot(X_t, projections) @@ -206,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -256,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, + p=2, seed=None, log=False): + r""" + Compute the spherical sliced-Wasserstein discrepancy. + + .. math:: + SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}} + + where: + + - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + + The function runs on backend but tensorflow is not supported. + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + X_t: ndarray, shape (n_samples_b, dim) + Samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional (default=2) + Power p used for computing the spherical sliced Wasserstein + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_sphere returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + >>> n_samples_a = 20 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True)) + >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + if a is not None and b is not None: + nx = get_backend(X_s, X_t, a, b) + else: + nx = get_backend(X_s, X_t) + + n, d = X_s.shape + m, _ = X_t.shape + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("Xt is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) + + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) + + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)) + + projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + res = nx.mean(projected_emd) ** (1 / p) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False): + r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution. + + .. math:: + SSW_2(\mu_n, \nu) + + where + + - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}` + - :math:`\nu=\mathrm{Unif}(S^1)` + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + --------- + >>> np.random.seed(42) + >>> x0 = np.random.randn(500,3) + >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True)) + >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42) + >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3) + True + + References: + ----------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + if a is not None: + nx = get_backend(X_s, a) + else: + nx = get_backend(X_s) + + n, d = X_s.shape + + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coordinates on [0,1[ + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + + projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) + res = nx.mean(projected_emd) ** (1 / 2) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res |