summaryrefslogtreecommitdiff
path: root/ot/gpu/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r--ot/gpu/bregman.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 4d4a8b7..f91f15f 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -4,10 +4,11 @@ Bregman projections for regularized OT with GPU
"""
import numpy as np
+import cudamat
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, cudamat=None):
+ log=False):
# init data
Nini = len(a)
Nfin = len(b)
@@ -74,7 +75,6 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
log['u'] = u_GPU.asarray()
log['v'] = v_GPU.asarray()
- # print('err=',err,' cpt=',cpt)
K_GPU.mult_by_col(u_GPU, target=K_GPU)
K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
if log: