summaryrefslogtreecommitdiff
path: root/ot/sliced.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-05 15:57:08 +0100
committerGitHub <noreply@github.com>2021-11-05 15:57:08 +0100
commit0eac835c70cc1a13bb998f3b6cdb0515fafc05e1 (patch)
treeb0c0fbce0109ba460a67a6356dc0ff03e2b3c1d5 /ot/sliced.py
parent0e431c203a66c6d48e6bb1efeda149460472a0f0 (diff)
[MRG] Tests with types/device on sliced/bregman/gromov functions (#303)
* First draft : making pytest use gpu for torch testing * bug solve * Revert "bug solve" This reverts commit 29b013abd162f8693128f26d8129186b79923609. * Revert "First draft : making pytest use gpu for torch testing" This reverts commit 2778175bcc338016c704efa4187d132fe5162e3a. * sliced * sliced * ot 1dsolver * bregman * better print * jax works with sinkhorn, sinkhorn_log and sinkhornn_stabilized, no need to skip them * gromov & entropic gromov
Diffstat (limited to 'ot/sliced.py')
-rw-r--r--ot/sliced.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/sliced.py b/ot/sliced.py
index 7c09111..cf2d3be 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -139,9 +139,9 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
X_t.shape[1]))
if a is None:
- a = nx.full(n, 1 / n)
+ a = nx.full(n, 1 / n, type_as=X_s)
if b is None:
- b = nx.full(m, 1 / m)
+ b = nx.full(m, 1 / m, type_as=X_s)
d = X_s.shape[1]
@@ -238,9 +238,9 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
X_t.shape[1]))
if a is None:
- a = nx.full(n, 1 / n)
+ a = nx.full(n, 1 / n, type_as=X_s)
if b is None:
- b = nx.full(m, 1 / m)
+ b = nx.full(m, 1 / m, type_as=X_s)
d = X_s.shape[1]