summaryrefslogtreecommitdiff
path: root/ot/gpu/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 15:05:09 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 15:05:09 +0200
commit75e78022d2df350ea220cee1b5e759ef9fc35a5b (patch)
tree81f6ac35917d23dbca1bc95b7c5296d1ac2175a0 /ot/gpu/bregman.py
parentf45f7a68b221ec5b619b8fd8de797815a1eecf43 (diff)
update tests
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r--ot/gpu/bregman.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 6714098..600ead4 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -90,14 +90,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
"""
- a = cp.asarray(a, dtype=np.float64)
- b = cp.asarray(b, dtype=np.float64)
- M = cp.asarray(M, dtype=np.float64)
+ a = cp.asarray(a)
+ b = cp.asarray(b)
+ M = cp.asarray(M)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = np.ones((M.shape[0],)) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = np.ones((M.shape[1],)) / M.shape[1]
# init data
Nini = len(a)