diff options
Diffstat (limited to 'ot/sliced.py')
-rw-r--r-- | ot/sliced.py | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/ot/sliced.py b/ot/sliced.py index 3a1644d..fd86df9 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -260,7 +260,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - p=2, seed=None, log=False): + p=2, projections=None, seed=None, log=False): r""" Compute the spherical sliced-Wasserstein discrepancy. @@ -287,6 +287,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, Number of projections used for the Monte-Carlo approximation p: float, optional (default=2) Power p used for computing the spherical sliced Wasserstein + projections: shape (n_projections, dim, 2), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional Seed used for random number generator log: bool, optional @@ -326,22 +328,25 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): raise ValueError("X_s is not on the sphere.") if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): - raise ValueError("Xt is not on the sphere.") + raise ValueError("X_t is not on the sphere.") - # Uniforms and independent samples on the Stiefel manifold V_{d,2} - if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': - Z = seed.randn(n_projections, d, 2) + if projections is None: + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) else: - if seed is not None: - nx.seed(seed) - Z = nx.randn(n_projections, d, 2, type_as=X_s) - - projections, _ = nx.qr(Z) + n_projections = projections.shape[0] # Projection on S^1 # Projection on plane - Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) - Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) + Xps = nx.einsum("ikj, lk -> ilj", projections, X_s) + Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t) # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) @@ -425,9 +430,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log # Projection on S^1 # Projection on plane - Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xps = nx.einsum("ikj, lk -> ilj", projections, X_s) + # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coordinates on [0,1[ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) |