From b853e6a9db3dfc1f23d8f0b5101ca82dee686ecb Mon Sep 17 00:00:00 2001 From: Titouan Vayer Date: Fri, 16 Dec 2022 16:26:10 +0100 Subject: [MRG] Change the number of projection to match the predefined case (#419) --- ot/sliced.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'ot') diff --git a/ot/sliced.py b/ot/sliced.py index cf2d3be..20891a4 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, if projections is None: projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + else: + n_projections = projections.shape[1] X_s_projections = nx.dot(X_s, projections) X_t_projections = nx.dot(X_t, projections) -- cgit v1.2.3