diff options
author | LeoGautheron <leo_g_autheron@hotmail.fr> | 2018-07-11 22:28:38 +0200 |
---|---|---|
committer | LeoGautheron <leo_g_autheron@hotmail.fr> | 2018-07-11 22:28:38 +0200 |
commit | cb6bdc516697e3bad6776b897f22c8b6a22f13cd (patch) | |
tree | 7b040dc809893613390d202b41e539cc69972fe7 | |
parent | 39cbcd302c1d1e275c628d3bac073ec1f89596c6 (diff) |
Speed-up Sinkhorn
The speed-up comes from three places:
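- pairwise squared Euclidean distances are now computed with sklearn.metrics.pairwise.euclidean_distances, which is faster than scipy's cdist
- K = np.exp(-M / reg) is now computed in place (np.divide, then np.exp, both with out=K), which is faster
- the error computed every 10 iterations now reuses preallocated buffers (tmp, tmp2), which is faster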
Example benchmark, using this short script:
# Benchmark script: times SinkhornTransport.fit and SinkhornLpl1Transport.fit
# on the same randomly generated data and prints each wall-clock duration.
import time
import numpy as np
import ot
# Fixed seed so every run draws the same data.
rng = np.random.RandomState(0)
transport = ot.da.SinkhornTransport()
# NOTE(review): the clock starts before the data is drawn, so the first
# measurement also covers the rng.randn / rng.randint calls below.
time1 = time.time()
# Xs and Xt are 10000 x 100 standard-normal samples; ys holds 10000 integer
# labels drawn from {0, 1}.
Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100)
transport.fit(Xs=Xs, Xt=Xt)
# time2 closes the first measurement and opens the second one, so the second
# measurement also covers constructing SinkhornLpl1Transport.
time2 = time.time()
print("OT Computation Time {:6.2f} sec".format(time2-time1))
transport = ot.da.SinkhornLpl1Transport()
# Same data as above, this time also passing the labels ys.
transport.fit(Xs=Xs, ys=ys, Xt=Xt)
time3 = time.time()
print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2))
Before
OT Computation Time 19.93 sec
OT LpL1 Computation Time 133.43 sec
After
OT Computation Time 7.55 sec
OT LpL1 Computation Time 82.25 sec
-rw-r--r-- | ot/bregman.py | 14 | ||||
-rw-r--r-- | ot/utils.py | 4 |
2 files changed, 14 insertions, 4 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b017c1a..55c44f6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -344,8 +344,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
 
     # print(reg)
 
-    K = np.exp(-M / reg)
+    K = np.empty(M.shape, dtype=M.dtype)
+    np.divide(M, -reg, out=K)
+    np.exp(K, out=K)
+
     # print(np.min(K))
+    tmp = np.empty(K.shape, dtype=M.dtype)
+    tmp2 = np.empty(b.shape, dtype=M.dtype)
 
     Kp = (1 / a).reshape(-1, 1) * K
     cpt = 0
@@ -373,8 +378,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
                 err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
                     np.sum((v - vprev)**2) / np.sum((v)**2)
             else:
-                transp = u.reshape(-1, 1) * (K * v)
-                err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+                np.multiply(u.reshape(-1, 1), K, out=tmp)
+                np.multiply(tmp, v.reshape(1, -1), out=tmp)
+                np.sum(tmp, axis=0, out=tmp2)
+                tmp2 -= b
+                err = np.linalg.norm(tmp2)**2
             if log:
                 log['err'].append(err)
 
diff --git a/ot/utils.py b/ot/utils.py
index 7dac283..5b052ac 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -13,6 +13,7 @@ import time
 
 import numpy as np
 from scipy.spatial.distance import cdist
+from sklearn.metrics.pairwise import euclidean_distances
 import sys
 import warnings
 try:
@@ -104,7 +105,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
     """
     if x2 is None:
         x2 = x1
-
+    if metric == "sqeuclidean":
+        return euclidean_distances(x1, x2, squared=True)
     return cdist(x1, x2, metric=metric)
 
 