summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTitouan Vayer <titouan.vayer@gmail.com>2022-12-16 16:26:10 +0100
committerGitHub <noreply@github.com>2022-12-16 16:26:10 +0100
commitb853e6a9db3dfc1f23d8f0b5101ca82dee686ecb (patch)
tree90de4ab0856d7d20189aa805eed6a2b3453d18d2 /test
parent0411ea22a96f9c22af30156b45c16ef39ffb520d (diff)
[MRG] Change the number of projection to match the predefined case (#419)
Diffstat (limited to 'test')
-rw-r--r--test/test_sliced.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..eb13469 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):
n = 100