summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorTitouan Vayer <titouan.vayer@gmail.com>2022-12-16 16:26:10 +0100
committerGitHub <noreply@github.com>2022-12-16 16:26:10 +0100
commitb853e6a9db3dfc1f23d8f0b5101ca82dee686ecb (patch)
tree90de4ab0856d7d20189aa805eed6a2b3453d18d2 /ot
parent0411ea22a96f9c22af30156b45c16ef39ffb520d (diff)
[MRG] Change the number of projection to match the predefined case (#419)
Diffstat (limited to 'ot')
-rw-r--r--ot/sliced.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/ot/sliced.py b/ot/sliced.py
index cf2d3be..20891a4 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
if projections is None:
projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+ else:
+ n_projections = projections.shape[1]
X_s_projections = nx.dot(X_s, projections)
X_t_projections = nx.dot(X_t, projections)