summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile11
-rw-r--r--ot/bregman.py15
2 files changed, 16 insertions, 10 deletions
diff --git a/Makefile b/Makefile
index 1abc6e9..84a644b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,7 @@
PYTHON=python3
+branch := $(shell git symbolic-ref --short -q HEAD)
help :
@echo "The following make targets are available:"
@@ -57,6 +58,16 @@ rdoc :
notebook :
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
+bench :
+ @git stash >/dev/null 2>&1
+ @echo 'Branch master'
+ @git checkout master >/dev/null 2>&1
+ python3 $(script)
+ @echo 'Branch $(branch)'
+ @git checkout $(branch) >/dev/null 2>&1
+ python3 $(script)
+ @git stash apply >/dev/null 2>&1
+
autopep8 :
autopep8 -ir test ot examples --jobs -1
diff --git a/ot/bregman.py b/ot/bregman.py
index c8e69ce..c755f51 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
np.exp(K, out=K)
# print(np.min(K))
- tmp = np.empty(K.shape, dtype=M.dtype)
tmp2 = np.empty(b.shape, dtype=M.dtype)
Kp = (1 / a).reshape(-1, 1) * K
@@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
+
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)
@@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
- np.multiply(u.reshape(-1, 1), K, out=tmp)
- np.multiply(tmp, v.reshape(1, -1), out=tmp)
- np.sum(tmp, axis=0, out=tmp2)
- tmp2 -= b
- err = np.linalg.norm(tmp2)**2
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ np.einsum('i,ij,j->j', u, K, v, out=tmp2)
+ err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)
@@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if nbb: # return only loss
- res = np.zeros((nbb))
- for i in range(nbb):
- res[i] = np.sum(
- u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else: