path: root/ot/lp
author    ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>    2021-11-02 13:42:02 +0100
committer GitHub <noreply@github.com>    2021-11-02 13:42:02 +0100
commit    a335324d008e8982be61d7ace937815a2bfa98f9 (patch)
tree      83c7f637597f10f6f3d20b15532e53fc65b51f22 /ot/lp
parent    0cb2b2efe901ed74c614046d250518769f870313 (diff)
[MRG] Backend for gromov (#294)
* bregman: small correction
* gromov backend first draft
* Removing decorators
* Reworked casting method
* Bug solve
* Removing casting
* Bug solve
* toarray renamed todense ; expand_dims removed
* Warning (jax not supporting sparse matrix) moved
* Mistake corrected
* test backend
* Sparsity test for older versions of pytorch
* Trying pytorch/1.10
* Attempt to correct torch sparse bug
* Backend version of gromov tests
* Random state introduced for remaining gromov functions
* review changes
* code coverage
* Docs (first draft, to be continued)
* Gromov docs
* Prettified docs
* mistake corrected in the docs
* little change

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
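The core pattern this patch applies to emd_1d is POT's backend abstraction: list_to_array normalizes plain lists to arrays, get_backend picks the numerical module (NumPy, PyTorch, JAX, ...) that matches the inputs, and type_as keeps new arrays on the same dtype/device as the data. A minimal sketch of that pattern, not part of the patch itself; the helper names come from ot.utils and ot.backend as used in the diff below, while the function and sample data are made up for illustration:

    # Sketch of the backend-agnostic pattern used in this patch.
    # uniform_weights and the sample input are hypothetical.
    import numpy as np
    from ot.backend import get_backend
    from ot.utils import list_to_array

    def uniform_weights(x):
        # Convert a plain Python list to an array; arrays pass through unchanged
        x = list_to_array(x)
        # Dispatch to the backend (NumPy, PyTorch, JAX, ...) of the input
        nx = get_backend(x)
        # type_as keeps the result on the same dtype/device as x
        return nx.ones((x.shape[0],), type_as=x) / x.shape[0]

    print(uniform_weights(np.array([0.0, 1.0, 2.0])))  # -> [1/3, 1/3, 1/3]

The same three calls appear in the emd_1d hunks below, replacing the previous hard-coded np.asarray / np.ones pairs.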
Diffstat (limited to 'ot/lp')
-rw-r--r--  ot/lp/__init__.py  58
1 file changed, 36 insertions(+), 22 deletions(-)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index c6757d1..4e95ccf 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -691,10 +691,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
transportation matrix)
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- x_a = np.asarray(x_a, dtype=np.float64)
- x_b = np.asarray(x_b, dtype=np.float64)
+ a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
+ nx = get_backend(x_a, x_b)
assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
"emd_1d should only be used with monodimensional data"
@@ -702,27 +700,43 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
"emd_1d should only be used with monodimensional data"
# if empty array given then use uniform distributions
- if a.ndim == 0 or len(a) == 0:
- a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
- if b.ndim == 0 or len(b) == 0:
- b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+ if a is None or a.ndim == 0 or len(a) == 0:
+ a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
+ if b is None or b.ndim == 0 or len(b) == 0:
+ b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
# ensure that same mass
- np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum')
- b=b*a.sum()/b.sum()
-
- x_a_1d = x_a.reshape((-1,))
- x_b_1d = x_b.reshape((-1,))
- perm_a = np.argsort(x_a_1d)
- perm_b = np.argsort(x_b_1d)
-
- G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
- x_a_1d[perm_a], x_b_1d[perm_b],
- metric=metric, p=p)
- G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
- shape=(a.shape[0], b.shape[0]))
+ np.testing.assert_almost_equal(
+ nx.sum(a, axis=0),
+ nx.sum(b, axis=0),
+ err_msg='a and b vector must have the same sum'
+ )
+ b = b * nx.sum(a) / nx.sum(b)
+
+ x_a_1d = nx.reshape(x_a, (-1,))
+ x_b_1d = nx.reshape(x_b, (-1,))
+ perm_a = nx.argsort(x_a_1d)
+ perm_b = nx.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(
+ nx.to_numpy(a[perm_a]),
+ nx.to_numpy(b[perm_b]),
+ nx.to_numpy(x_a_1d[perm_a]),
+ nx.to_numpy(x_b_1d[perm_b]),
+ metric=metric, p=p
+ )
+
+ G = nx.coo_matrix(
+ G_sorted,
+ perm_a[indices[:, 0]],
+ perm_b[indices[:, 1]],
+ shape=(a.shape[0], b.shape[0]),
+ type_as=x_a
+ )
if dense:
- G = G.toarray()
+ G = nx.todense(G)
+ elif str(nx) == "jax":
+ warnings.warn("JAX does not support sparse matrices, converting to dense")
if log:
log = {'cost': cost}
return G, log
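With these changes, emd_1d infers uniform weights and builds the transport plan through the backend, so it can be called directly on backend arrays. A hedged usage sketch, assuming a PyTorch install; the sample points are made up for illustration:

    # Illustrative only: after this patch, emd_1d should accept backend
    # arrays (e.g. torch tensors) and return the plan in the same type.
    import torch
    import ot

    x_a = torch.tensor([0.0, 1.0, 2.0])
    x_b = torch.tensor([0.5, 1.5, 2.5])
    # a and b default to uniform weights; with dense=True (the default)
    # the plan G comes back as a dense torch tensor
    G = ot.emd_1d(x_a, x_b, metric='sqeuclidean')
    print(G)

Note that the sorting and the emd_1d_sorted solver still run in NumPy (via nx.to_numpy), so only the inputs and outputs are backend-typed, not the inner loop.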