summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-10-25 17:35:36 +0200
committerGitHub <noreply@github.com>2021-10-25 17:35:36 +0200
commit76450dddf8dd62b9714b72e99ae075516246d433 (patch)
tree67de8de1c185cc8e7fc33a1fc0613015824d1fbb /ot/lp/__init__.py
parent7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff)
[MRG] Backend for optim (#282)
* Backend for optim * Bug solve * Doc update * backend tests now with fixture * Unused imports removed * Docs * Docs * Docs * Outer product backend docs * Prettier docs * Pep8 * Mistakes corrected Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index b907b10..c6757d1 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a0, b0, M0 = a, b, M
nx = get_backend(M0, a0, b0)
-
+
# convert to numpy
M = nx.to_numpy(M)
a = nx.to_numpy(a)
b = nx.to_numpy(b)
-
+
# ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)