diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-10-25 17:35:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-25 17:35:36 +0200 |
commit | 76450dddf8dd62b9714b72e99ae075516246d433 (patch) | |
tree | 67de8de1c185cc8e7fc33a1fc0613015824d1fbb /ot/lp/__init__.py | |
parent | 7a65086dd340265d0223eb8ffb5c9a5152a82dff (diff) |
[MRG] Backend for optim (#282)
* Backend for optim
* Bug solve
* Doc update
* backend tests now with fixture
* Unused imports removed
* Docs
* Docs
* Docs
* Outer product backend docs
* Prettier docs
* Pep8
* Mistakes corrected
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b907b10..c6757d1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -281,12 +281,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a0, b0, M0 = a, b, M nx = get_backend(M0, a0, b0) - + # convert to numpy M = nx.to_numpy(M) a = nx.to_numpy(a) b = nx.to_numpy(b) - + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) |