summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index b907b10..c6757d1 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a0, b0, M0 = a, b, M
nx = get_backend(M0, a0, b0)
-
+
# convert to numpy
M = nx.to_numpy(M)
a = nx.to_numpy(a)
b = nx.to_numpy(b)
-
+
# ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)