diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-07-24 13:55:55 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-07-24 13:55:55 +0200 |
commit | c0c959da8e62d57587ed36e8ba359ca095c5b423 (patch) | |
tree | 5d849205436a873c326b5c7ebe071c919fb82226 /ot | |
parent | 5cd6c0aae23a36fe27c188cdd18c7f0fba8a0360 (diff) |
speedup einsum constraint violation
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index c8e69ce..26b7b53 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, np.exp(K, out=K) # print(np.min(K)) - tmp = np.empty(K.shape, dtype=M.dtype) tmp2 = np.empty(b.shape, dtype=M.dtype) Kp = (1 / a).reshape(-1, 1) * K @@ -379,11 +378,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: - np.multiply(u.reshape(-1, 1), K, out=tmp) - np.multiply(tmp, v.reshape(1, -1), out=tmp) - np.sum(tmp, axis=0, out=tmp2) - tmp2 -= b - err = np.linalg.norm(tmp2)**2 + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j',u,K,v,out=tmp2) + err = np.linalg.norm(tmp2-b)**2 # violation of marginal if log: log['err'].append(err) |