author     ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>    2021-11-02 13:42:02 +0100
committer  GitHub <noreply@github.com>                                              2021-11-02 13:42:02 +0100
commit     a335324d008e8982be61d7ace937815a2bfa98f9 (patch)
tree       83c7f637597f10f6f3d20b15532e53fc65b51f22 /ot/lp
parent     0cb2b2efe901ed74c614046d250518769f870313 (diff)
[MRG] Backend for gromov (#294)
* bregman: small correction
* gromov backend first draft
* Removing decorators
* Reworked casting method
* Bug fix
* Removing casting
* Bug fix
* toarray renamed to todense; expand_dims removed
* Warning (jax not supporting sparse matrix) moved
* Mistake corrected
* test backend
* Sparsity test for older versions of pytorch
* Trying pytorch/1.10
* Attempt to correct torch sparse bug
* Backend version of gromov tests
* Random state introduced for remaining gromov functions
* review changes
* code coverage
* Docs (first draft, to be continued)
* Gromov docs
* Prettified docs
* mistake corrected in the docs
* little change
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
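
For readers skimming the log, a rough illustration of what backend support for the Gromov-Wasserstein solvers enables (a sketch, not part of this diff; it assumes POT with the torch backend installed and uses ot.gromov.gromov_wasserstein together with the ot.dist and ot.unif helpers; sample sizes are illustrative):

    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(42)
    xs = rng.randn(20, 2)        # source samples
    xt = rng.randn(30, 3)        # target samples
    C1 = ot.dist(xs, xs)         # intra-domain cost matrices
    C2 = ot.dist(xt, xt)
    p = ot.unif(20)              # uniform marginals
    q = ot.unif(30)

    # NumPy inputs give a NumPy coupling
    G_np = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')

    # With backend dispatch, torch tensors are accepted end to end and
    # the coupling comes back as a torch tensor
    G_th = ot.gromov.gromov_wasserstein(
        torch.tensor(C1), torch.tensor(C2),
        torch.tensor(p), torch.tensor(q), 'square_loss'
    )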
Diffstat (limited to 'ot/lp')
-rw-r--r-- | ot/lp/__init__.py | 58 |
1 file changed, 36 insertions, 22 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index c6757d1..4e95ccf 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -691,10 +691,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of
         the transportation matrix)
     """
-    a = np.asarray(a, dtype=np.float64)
-    b = np.asarray(b, dtype=np.float64)
-    x_a = np.asarray(x_a, dtype=np.float64)
-    x_b = np.asarray(x_b, dtype=np.float64)
+    a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
+    nx = get_backend(x_a, x_b)
 
     assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
         "emd_1d should only be used with monodimensional data"
@@ -702,27 +700,43 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
         "emd_1d should only be used with monodimensional data"
 
     # if empty array given then use uniform distributions
-    if a.ndim == 0 or len(a) == 0:
-        a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
-    if b.ndim == 0 or len(b) == 0:
-        b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+    if a is None or a.ndim == 0 or len(a) == 0:
+        a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
+    if b is None or b.ndim == 0 or len(b) == 0:
+        b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
 
     # ensure that same mass
-    np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum')
-    b=b*a.sum()/b.sum()
-
-    x_a_1d = x_a.reshape((-1,))
-    x_b_1d = x_b.reshape((-1,))
-    perm_a = np.argsort(x_a_1d)
-    perm_b = np.argsort(x_b_1d)
-
-    G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
-                                            x_a_1d[perm_a], x_b_1d[perm_b],
-                                            metric=metric, p=p)
-    G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
-                   shape=(a.shape[0], b.shape[0]))
+    np.testing.assert_almost_equal(
+        nx.sum(a, axis=0),
+        nx.sum(b, axis=0),
+        err_msg='a and b vector must have the same sum'
+    )
+    b = b * nx.sum(a) / nx.sum(b)
+
+    x_a_1d = nx.reshape(x_a, (-1,))
+    x_b_1d = nx.reshape(x_b, (-1,))
+    perm_a = nx.argsort(x_a_1d)
+    perm_b = nx.argsort(x_b_1d)
+
+    G_sorted, indices, cost = emd_1d_sorted(
+        nx.to_numpy(a[perm_a]),
+        nx.to_numpy(b[perm_b]),
+        nx.to_numpy(x_a_1d[perm_a]),
+        nx.to_numpy(x_b_1d[perm_b]),
+        metric=metric, p=p
+    )
+
+    G = nx.coo_matrix(
+        G_sorted,
+        perm_a[indices[:, 0]],
+        perm_b[indices[:, 1]],
+        shape=(a.shape[0], b.shape[0]),
+        type_as=x_a
+    )
     if dense:
-        G = G.toarray()
+        G = nx.todense(G)
+    elif str(nx) == "jax":
+        warnings.warn("JAX does not support sparse matrices, converting to dense")
     if log:
         log = {'cost': cost}
     return G, log
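
For context (not part of this commit), a minimal usage sketch of the backend-dispatched emd_1d shown above: the same call accepts NumPy arrays or torch tensors, and the returned transport plan stays in the caller's framework. The sample sizes and the torch conversion below are illustrative, assuming POT with the torch backend installed.

    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(0)
    x_a = rng.randn(50)   # 1d source positions
    x_b = rng.randn(60)   # 1d target positions

    # NumPy inputs: a and b are left as None, so uniform weights are
    # created on the NumPy backend
    G_np = ot.emd_1d(x_a, x_b, metric='sqeuclidean')

    # Torch inputs: get_backend selects the torch backend, so the dense
    # coupling comes back as a torch tensor of shape (50, 60)
    G_th = ot.emd_1d(torch.from_numpy(x_a), torch.from_numpy(x_b))

    assert isinstance(G_th, torch.Tensor)
    assert np.allclose(G_np, G_th.numpy())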