summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTitouan Vayer <titouan.vayer@gmail.com>2022-12-16 16:26:10 +0100
committerGitHub <noreply@github.com>2022-12-16 16:26:10 +0100
commitb853e6a9db3dfc1f23d8f0b5101ca82dee686ecb (patch)
tree90de4ab0856d7d20189aa805eed6a2b3453d18d2
parent0411ea22a96f9c22af30156b45c16ef39ffb520d (diff)
[MRG] Change the number of projection to match the predefined case (#419)
-rw-r--r--ot/sliced.py2
-rw-r--r--test/test_sliced.py14
2 files changed, 16 insertions, 0 deletions
diff --git a/ot/sliced.py b/ot/sliced.py
index cf2d3be..20891a4 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
if projections is None:
projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+ else:
+ n_projections = projections.shape[1]
X_s_projections = nx.dot(X_s, projections)
X_t_projections = nx.dot(X_t, projections)
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..eb13469 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):
n = 100