diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 15:05:09 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 15:05:09 +0200 |
commit | 75e78022d2df350ea220cee1b5e759ef9fc35a5b (patch) | |
tree | 81f6ac35917d23dbca1bc95b7c5296d1ac2175a0 /ot/gpu/bregman.py | |
parent | f45f7a68b221ec5b619b8fd8de797815a1eecf43 (diff) |
update tests
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r-- | ot/gpu/bregman.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 6714098..600ead4 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -90,14 +90,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, """ - a = cp.asarray(a, dtype=np.float64) - b = cp.asarray(b, dtype=np.float64) - M = cp.asarray(M, dtype=np.float64) + a = cp.asarray(a) + b = cp.asarray(b) + M = cp.asarray(M) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = np.ones((M.shape[0],)) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = np.ones((M.shape[1],)) / M.shape[1] # init data Nini = len(a) |