summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..eb13469 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):
n = 100