summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 10:07:00 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 10:07:00 +0200
commiteb17e022fee209f3d363a6f8dcbb0064fccde1ad (patch)
treea8038a04034415dfd4c9e7eb91bf147639b56865 /ot/bregman.py
parent653fd0084c529bc74dabf93c68a9bdd5ac8f377a (diff)
correct if error bug
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py58
1 files changed, 29 insertions, 29 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 1f9874e..8538c92 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -103,10 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
- if method.lower() == 'greenkhorn':
+ elif method.lower() == 'greenkhorn':
def sink():
return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
def sink():
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
@@ -417,17 +417,16 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log = False):
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
-
+
The algorithm used is based on the paper
-
+
Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
by Jason Altschuler, Jonathan Weed, Philippe Rigollet
appeared at NIPS 2017
-
+
which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
The function solves the following optimization problem:
@@ -499,21 +498,21 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log
ot.optim.cg : General regularized OT
"""
-
+
i = 0
-
+
n = a.shape[0]
m = b.shape[0]
-
+
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
-
- u = np.ones(n)/n
- v = np.ones(m)/m
+
+ u = np.ones(n) / n
+ v = np.ones(m) / m
G = np.diag(u)@K@np.diag(v)
-
+
one_n = np.ones(n)
one_m = np.ones(m)
viol = G@one_m - a
@@ -524,41 +523,42 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log
log['v'] = v
while i < numItermax and stopThr_val > stopThr:
- i +=1
+ i += 1
i_1 = np.argmax(np.abs(viol))
i_2 = np.argmax(np.abs(viol_2))
m_viol_1 = np.abs(viol[i_1])
m_viol_2 = np.abs(viol_2[i_2])
- stopThr_val = np.maximum(m_viol_1,m_viol_2)
-
+ stopThr_val = np.maximum(m_viol_1, m_viol_2)
+
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- u[i_1] = a[i_1]/(K[i_1,:]@v)
- G[i_1,:] = u[i_1]*K[i_1,:]*v
+ u[i_1] = a[i_1] / (K[i_1, :]@v)
+ G[i_1, :] = u[i_1] * K[i_1, :] * v
- viol[i_1] = u[i_1]*K[i_1,:]@v - a[i_1]
- viol_2 = viol_2 + ( K[i_1,:].T*(u[i_1] - old_u)*v)
+ viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1]
+ viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v)
else:
old_v = v[i_2]
- v[i_2] = b[i_2]/(K[:,i_2].T@u)
- G[:,i_2] = u*K[:,i_2]*v[i_2]
+ v[i_2] = b[i_2] / (K[:, i_2].T@u)
+ G[:, i_2] = u * K[:, i_2] * v[i_2]
#aviol = (G@one_m - a)
#aviol_2 = (G.T@one_n - b)
- viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u
- viol_2[i_2] = v[i_2]*K[:,i_2]@u - b[i_2]
-
+ viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u
+ viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2]
+
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
-
+
if log:
log['u'] = u
log['v'] = v
-
+
if log:
- return G,log
+ return G, log
else:
return G
+
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""