diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 10:07:00 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-09-24 10:07:00 +0200 |
commit | eb17e022fee209f3d363a6f8dcbb0064fccde1ad (patch) | |
tree | a8038a04034415dfd4c9e7eb91bf147639b56865 /ot/bregman.py | |
parent | 653fd0084c529bc74dabf93c68a9bdd5ac8f377a (diff) |
correct if error bug
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 58 |
1 files changed, 29 insertions, 29 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 1f9874e..8538c92 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -103,10 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - if method.lower() == 'greenkhorn': + elif method.lower() == 'greenkhorn': def sink(): return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log) elif method.lower() == 'sinkhorn_stabilized': def sink(): return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, @@ -417,17 +417,16 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) - -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log = False): +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): """ Solve the entropic regularization optimal transport problem and return the OT matrix - + The algorithm used is based on the paper - + Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration by Jason Altschuler, Jonathan Weed, Philippe Rigollet appeared at NIPS 2017 - + which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. The function solves the following optimization problem: @@ -499,21 +498,21 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log ot.optim.cg : General regularized OT """ - + i = 0 - + n = a.shape[0] m = b.shape[0] - + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - - u = np.ones(n)/n - v = np.ones(m)/m + + u = np.ones(n) / n + v = np.ones(m) / m G = np.diag(u)@K@np.diag(v) - + one_n = np.ones(n) one_m = np.ones(m) viol = G@one_m - a @@ -524,41 +523,42 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log log['v'] = v while i < numItermax and stopThr_val > stopThr: - i +=1 + i += 1 i_1 = np.argmax(np.abs(viol)) i_2 = np.argmax(np.abs(viol_2)) m_viol_1 = np.abs(viol[i_1]) m_viol_2 = np.abs(viol_2[i_2]) - stopThr_val = np.maximum(m_viol_1,m_viol_2) - + stopThr_val = np.maximum(m_viol_1, m_viol_2) + if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1]/(K[i_1,:]@v) - G[i_1,:] = u[i_1]*K[i_1,:]*v + u[i_1] = a[i_1] / (K[i_1, :]@v) + G[i_1, :] = u[i_1] * K[i_1, :] * v - viol[i_1] = u[i_1]*K[i_1,:]@v - a[i_1] - viol_2 = viol_2 + ( K[i_1,:].T*(u[i_1] - old_u)*v) + viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1] + viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v) else: old_v = v[i_2] - v[i_2] = b[i_2]/(K[:,i_2].T@u) - G[:,i_2] = u*K[:,i_2]*v[i_2] + v[i_2] = b[i_2] / (K[:, i_2].T@u) + G[:, i_2] = u * K[:, i_2] * v[i_2] #aviol = (G@one_m - a) #aviol_2 = (G.T@one_n - b) - viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u - viol_2[i_2] = v[i_2]*K[:,i_2]@u - b[i_2] - + viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u + viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2] + #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) - + if log: log['u'] = u log['v'] = v - + if log: - return G,log + return G, log else: return G + def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ |