diff options
Diffstat (limited to 'ot/sliced.py')
-rw-r--r-- | ot/sliced.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/sliced.py b/ot/sliced.py index 7c09111..cf2d3be 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, X_t.shape[1])) if a is None: - a = nx.full(n, 1 / n) + a = nx.full(n, 1 / n, type_as=X_s) if b is None: - b = nx.full(m, 1 / m) + b = nx.full(m, 1 / m, type_as=X_s) d = X_s.shape[1] @@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, X_t.shape[1])) if a is None: - a = nx.full(n, 1 / n) + a = nx.full(n, 1 / n, type_as=X_s) if b is None: - b = nx.full(m, 1 / m) + b = nx.full(m, 1 / m, type_as=X_s) d = X_s.shape[1] |