From cb6bdc516697e3bad6776b897f22c8b6a22f13cd Mon Sep 17 00:00:00 2001 From: LeoGautheron Date: Wed, 11 Jul 2018 22:28:38 +0200 Subject: Speed-up Sinkhorn Speed-up in 3 places: - the computation of pairwise distance is faster with sklearn.metrics.pairwise.euclidean_distances - faster computation of K = np.exp(-M / reg) - faster computation of the error every 10 iterations Example with this little script: import time import numpy as np import ot rng = np.random.RandomState(0) transport = ot.da.SinkhornTransport() time1 = time.time() Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100) transport.fit(Xs=Xs, Xt=Xt) time2 = time.time() print("OT Computation Time {:6.2f} sec".format(time2-time1)) transport = ot.da.SinkhornLpl1Transport() transport.fit(Xs=Xs, ys=ys, Xt=Xt) time3 = time.time() print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2)) Before OT Computation Time 19.93 sec OT LpL1 Computation Time 133.43 sec After OT Computation Time 7.55 sec OT LpL1 Computation Time 82.25 sec --- ot/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'ot/utils.py') diff --git a/ot/utils.py b/ot/utils.py index 7dac283..5b052ac 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -13,6 +13,7 @@ import time import numpy as np from scipy.spatial.distance import cdist +from sklearn.metrics.pairwise import euclidean_distances import sys import warnings try: @@ -104,7 +105,8 @@ def dist(x1, x2=None, metric='sqeuclidean'): """ if x2 is None: x2 = x1 - + if metric == "sqeuclidean": + return euclidean_distances(x1, x2, squared=True) return cdist(x1, x2, metric=metric) -- cgit v1.2.3