diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2018-07-18 11:34:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-07-18 11:34:37 +0200 |
commit | 5cd6c0aae23a36fe27c188cdd18c7f0fba8a0360 (patch) | |
tree | d7b0968d9d50f0e40d225aeddc85b10d8e6c4cca /ot/bregman.py | |
parent | 7c5c8803b2bdb67545783db3321b9d5a81a063d6 (diff) | |
parent | 0764e356325df7e18f72c0ff468bfa8f8ee35059 (diff) |
Merge pull request #57 from LeoGautheron/master
Speed-up Sinkhorn
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index b017c1a..c8e69ce 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -344,8 +344,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # print(reg) - K = np.exp(-M / reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + # print(np.min(K)) + tmp = np.empty(K.shape, dtype=M.dtype) + tmp2 = np.empty(b.shape, dtype=M.dtype) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -373,8 +379,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: - transp = u.reshape(-1, 1) * (K * v) - err = np.linalg.norm((np.sum(transp, axis=0) - b))**2 + np.multiply(u.reshape(-1, 1), K, out=tmp) + np.multiply(tmp, v.reshape(1, -1), out=tmp) + np.sum(tmp, axis=0, out=tmp2) + tmp2 -= b + err = np.linalg.norm(tmp2)**2 if log: log['err'].append(err) |