summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/bregman.py14
-rw-r--r--ot/utils.py4
2 files changed, 14 insertions, 4 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b017c1a..55c44f6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -344,8 +344,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# print(reg)
- K = np.exp(-M / reg)
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
# print(np.min(K))
+ tmp = np.empty(K.shape, dtype=M.dtype)
+ tmp2 = np.empty(b.shape, dtype=M.dtype)
Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
@@ -373,8 +378,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- transp = u.reshape(-1, 1) * (K * v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ np.multiply(u.reshape(-1, 1), K, out=tmp)
+ np.multiply(tmp, v.reshape(1, -1), out=tmp)
+ np.sum(tmp, axis=0, out=tmp2)
+ tmp2 -= b
+ err = np.linalg.norm(tmp2)**2
if log:
log['err'].append(err)
diff --git a/ot/utils.py b/ot/utils.py
index 7dac283..5b052ac 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -13,6 +13,7 @@ import time
import numpy as np
from scipy.spatial.distance import cdist
+from sklearn.metrics.pairwise import euclidean_distances
import sys
import warnings
try:
@@ -104,7 +105,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
"""
if x2 is None:
x2 = x1
-
+ if metric == "sqeuclidean":
+ return euclidean_distances(x1, x2, squared=True)
return cdist(x1, x2, metric=metric)