diff options
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r-- | ot/gpu/bregman.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 2c3e317..7881c65 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -82,7 +82,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ + """ # init data Nini = len(a) Nfin = len(b) @@ -92,11 +92,11 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, # we assume that no distances are null except those of the diagonal of # distances - u = (np.ones(Nini)/Nini).reshape((Nini, 1)) + u = (np.ones(Nini) / Nini).reshape((Nini, 1)) u_GPU = cudamat.CUDAMatrix(u) a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1))) ones_GPU = cudamat.empty(u_GPU.shape).assign(1) - v = (np.ones(Nfin)/Nfin).reshape((Nfin, 1)) + v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1)) v_GPU = cudamat.CUDAMatrix(v) b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1))) @@ -121,7 +121,7 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU) if (np.any(KtransposeU_GPU.asarray() == 0) or - not u_GPU.allfinite() or not v_GPU.allfinite()): + not u_GPU.allfinite() or not v_GPU.allfinite()): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -142,7 +142,8 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, if verbose: if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19) + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) cpt += 1 if log: |