summaryrefslogtreecommitdiff
path: root/ot/gpu
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu')
-rw-r--r--ot/gpu/bregman.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 6714098..600ead4 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -90,14 +90,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
"""
- a = cp.asarray(a, dtype=np.float64)
- b = cp.asarray(b, dtype=np.float64)
- M = cp.asarray(M, dtype=np.float64)
+ a = cp.asarray(a)
+ b = cp.asarray(b)
+ M = cp.asarray(M)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = np.ones((M.shape[0],)) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = np.ones((M.shape[1],)) / M.shape[1]
# init data
Nini = len(a)