diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 14:30:44 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 14:30:44 +0200 |
commit | f45f7a68b221ec5b619b8fd8de797815a1eecf43 (patch) | |
tree | e3cc97cdf0c38e457303ceba32f7dadc20a12139 /ot/gpu/bregman.py | |
parent | d258c7d6936410cd78189445a0260d983f7684d6 (diff) |
pep8
Diffstat (limited to 'ot/gpu/bregman.py')
-rw-r--r-- | ot/gpu/bregman.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index 912104c..6714098 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -8,12 +8,11 @@ Bregman projections for regularized OT with GPU # # License: MIT License -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations +import cupy as np # np used for matrix computation +import cupy as cp # cp used for cupy specific operations from . import utils - def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, to_numpy=True, **kwargs): """ @@ -159,7 +158,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, np.sum((v - vprev)**2) / np.sum((v)**2) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2=np.sum(u[:,None]*K*v[None,:],0) + tmp2 = np.sum(u[:, None] * K * v[None, :], 0) #tmp2=np.einsum('i,ij,j->j', u, K, v) err = np.linalg.norm(tmp2 - b)**2 # violation of marginal if log: @@ -177,24 +176,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, if nbb: # return only loss #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory) - res=np.empty(nbb) + res = np.empty(nbb) for i in range(nbb): - res[i]=np.sum(u[:,None,i]*(K*M)*v[None,:,i]) + res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i]) if to_numpy: - res=utils.to_np(res) + res = utils.to_np(res) if log: return res, log else: return res else: # return OT matrix - res=u.reshape((-1, 1)) * K * v.reshape((1, -1)) + res = u.reshape((-1, 1)) * K * v.reshape((1, -1)) if to_numpy: - res=utils.to_np(res) + res = utils.to_np(res) if log: return res, log else: return res + # define sinkhorn as sinkhorn_knopp -sinkhorn=sinkhorn_knopp
\ No newline at end of file +sinkhorn = sinkhorn_knopp |