summaryrefslogtreecommitdiff
path: root/ot/sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/sliced.py')
-rw-r--r--ot/sliced.py33
1 files changed, 20 insertions, 13 deletions
diff --git a/ot/sliced.py b/ot/sliced.py
index 3a1644d..fd86df9 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -260,7 +260,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
- p=2, seed=None, log=False):
+ p=2, projections=None, seed=None, log=False):
r"""
Compute the spherical sliced-Wasserstein discrepancy.
@@ -287,6 +287,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
Number of projections used for the Monte-Carlo approximation
p: float, optional (default=2)
Power p used for computing the spherical sliced Wasserstein
+ projections: shape (n_projections, dim, 2), optional
+ Projection matrix (n_projections and seed are not used in this case)
seed: int or RandomState or None, optional
Seed used for random number generator
log: bool, optional
@@ -326,22 +328,25 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
raise ValueError("X_s is not on the sphere.")
if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
- raise ValueError("Xt is not on the sphere.")
+ raise ValueError("X_t is not on the sphere.")
- # Uniforms and independent samples on the Stiefel manifold V_{d,2}
- if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
- Z = seed.randn(n_projections, d, 2)
+ if projections is None:
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
else:
- if seed is not None:
- nx.seed(seed)
- Z = nx.randn(n_projections, d, 2, type_as=X_s)
-
- projections, _ = nx.qr(Z)
+ n_projections = projections.shape[0]
# Projection on S^1
# Projection on plane
- Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
- Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+ Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
+ Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t)
# Projection on sphere
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
@@ -425,9 +430,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
# Projection on S^1
# Projection on plane
- Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
+
# Projection on sphere
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+
# Get coordinates on [0,1[
Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))