diff options
author | Kilian <kilian.fatras@gmail.com> | 2018-08-29 14:22:40 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-29 14:22:40 -0700 |
commit | 15f4b29a91fda1dbd221e6e0a3443431d3d69257 (patch) | |
tree | 82c33ee8b09112b6a67ed614e370156e4144628f /ot/bregman.py | |
parent | 63b34bf012076eb89ed112122fdaa65667464ae7 (diff) | |
parent | 5180023fc49d15ad83faccc5674d5966fe9a0385 (diff) |
Merge branch 'master' into stochastic_OT
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index b017c1a..c755f51 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -344,8 +344,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # print(reg) - K = np.exp(-M / reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + # print(np.min(K)) + tmp2 = np.empty(b.shape, dtype=M.dtype) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -353,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v + KtransposeU = np.dot(K.T, u) v = np.divide(b, KtransposeU) u = 1. / np.dot(Kp, v) @@ -373,8 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: - transp = u.reshape(-1, 1) * (K * v) - err = np.linalg.norm((np.sum(transp, axis=0) - b))**2 + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j', u, K, v, out=tmp2) + err = np.linalg.norm(tmp2 - b)**2 # violation of marginal if log: log['err'].append(err) @@ -389,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if nbb: # return only loss - res = np.zeros((nbb)) - for i in range(nbb): - res[i] = np.sum( - u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M) + res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: |