author    | Rémi Flamary <remi.flamary@gmail.com> | 2021-11-04 15:19:57 +0100
committer | GitHub <noreply@github.com> | 2021-11-04 15:19:57 +0100
commit    | 0e431c203a66c6d48e6bb1efeda149460472a0f0
tree      | 22a447a1dbb1505b18f9e426e1761cf6b328b6eb /ot/lp
parent    | 2fe69eb130827560ada704bc25998397c4357821
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu
* pep 8 of course
* debug torch
* jax with gpu
* device put
* device put
* it works
* emd1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done + pep8
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/solver_1d.py | 14
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 42554aa..8b4d0c3 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     # ensure that same mass
     np.testing.assert_almost_equal(
-        nx.sum(a, axis=0),
-        nx.sum(b, axis=0),
+        nx.to_numpy(nx.sum(a, axis=0)),
+        nx.to_numpy(nx.sum(b, axis=0)),
         err_msg='a and b vector must have the same sum'
     )

     b = b * nx.sum(a) / nx.sum(b)
@@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     perm_b = nx.argsort(x_b_1d)

     G_sorted, indices, cost = emd_1d_sorted(
-        nx.to_numpy(a[perm_a]),
-        nx.to_numpy(b[perm_b]),
-        nx.to_numpy(x_a_1d[perm_a]),
-        nx.to_numpy(x_b_1d[perm_b]),
+        nx.to_numpy(a[perm_a]).astype(np.float64),
+        nx.to_numpy(b[perm_b]).astype(np.float64),
+        nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+        nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
         metric=metric, p=p
     )

@@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     elif str(nx) == "jax":
         warnings.warn("JAX does not support sparse matrices, converting to dense")

     if log:
-        log = {'cost': cost}
+        log = {'cost': nx.from_numpy(cost, type_as=x_a)}
         return G, log
     return G
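As a reading aid (not part of the commit), here is a minimal sketch of the behavior this patch targets: emd_1d converts backend arrays to float64 numpy arrays before calling the compiled emd_1d_sorted solver, and converts the resulting cost back with nx.from_numpy(cost, type_as=x_a), so the logged cost matches the input type. The sketch assumes POT is installed with the torch backend available; the sample sizes and seed are arbitrary.

# Sketch, assuming POT (ot) and torch are installed.
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
x_a = torch.from_numpy(rng.randn(50, 1))  # 50 source samples in 1D
x_b = torch.from_numpy(rng.randn(40, 1))  # 40 target samples in 1D

# With log=True, emd_1d returns the transport plan and a dict holding the
# cost; after this change, log['cost'] is returned in the same backend
# type as the inputs (here a torch tensor) rather than a numpy scalar.
G, log = ot.emd_1d(x_a, x_b, metric='sqeuclidean', log=True)
print(type(log['cost']))  # expected after this change: <class 'torch.Tensor'>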