summaryrefslogtreecommitdiff
path: root/ot/lp
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-11-04 15:19:57 +0100
committerGitHub <noreply@github.com>2021-11-04 15:19:57 +0100
commit0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch)
tree22a447a1dbb1505b18f9e426e1761cf6b328b6eb /ot/lp
parent2fe69eb130827560ada704bc25998397c4357821 (diff)
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu * pep 8 of couse * debug torch * jax with gpu * device put * device put * it works * emd1d and emd2_1d working * emd_1d and emd2_1d done * cleanup * of course * should work on gpu now * tests done+ pep8
Diffstat (limited to 'ot/lp')
-rw-r--r--ot/lp/solver_1d.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 42554aa..8b4d0c3 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
# ensure that same mass
np.testing.assert_almost_equal(
- nx.sum(a, axis=0),
- nx.sum(b, axis=0),
+ nx.to_numpy(nx.sum(a, axis=0)),
+ nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum'
)
b = b * nx.sum(a) / nx.sum(b)
@@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
perm_b = nx.argsort(x_b_1d)
G_sorted, indices, cost = emd_1d_sorted(
- nx.to_numpy(a[perm_a]),
- nx.to_numpy(b[perm_b]),
- nx.to_numpy(x_a_1d[perm_a]),
- nx.to_numpy(x_b_1d[perm_b]),
+ nx.to_numpy(a[perm_a]).astype(np.float64),
+ nx.to_numpy(b[perm_b]).astype(np.float64),
+ nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+ nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
metric=metric, p=p
)
@@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
elif str(nx) == "jax":
warnings.warn("JAX does not support sparse matrices, converting to dense")
if log:
- log = {'cost': cost}
+ log = {'cost': nx.from_numpy(cost, type_as=x_a)}
return G, log
return G