From 653fd0084c529bc74dabf93c68a9bdd5ac8f377a Mon Sep 17 00:00:00 2001 From: alain Date: Mon, 24 Sep 2018 09:05:47 +0200 Subject: adding greenkhorn --- ot/bregman.py | 151 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 150 insertions(+), 1 deletion(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index c755f51..1f9874e 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations @@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + if method.lower() == 'greenkhorn': + def sink(): + return greenkhorn(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log) elif method.lower() == 'sinkhorn_stabilized': def sink(): return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, @@ -197,6 +201,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + See Also @@ -204,6 +210,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] + ot.bregman.greenkhorn : Greenkhorn [21] ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] @@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) + +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log = False): + """ + Solve the entropic regularization optimal transport problem and return the OT matrix + + The algorithm used is based on the paper + + Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration + by Jason Altschuler, Jonathan Weed, Philippe Rigollet + appeared at NIPS 2017 + + which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed M if b is a matrix (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.sinkhorn(a,b,M,1) + array([[ 0.36552929, 0.13447071], + [ 0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + [21] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + i = 0 + + n = a.shape[0] + m = b.shape[0] + + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + u = np.ones(n)/n + v = np.ones(m)/m + G = np.diag(u)@K@np.diag(v) + + one_n = np.ones(n) + one_m = np.ones(m) + viol = G@one_m - a + viol_2 = G.T@one_n - b + stopThr_val = 1 + if log: + log['u'] = u + log['v'] = v + + while i < numItermax and stopThr_val > stopThr: + i +=1 + i_1 = np.argmax(np.abs(viol)) + i_2 = np.argmax(np.abs(viol_2)) + m_viol_1 = np.abs(viol[i_1]) + m_viol_2 = np.abs(viol_2[i_2]) + stopThr_val = np.maximum(m_viol_1,m_viol_2) + + if m_viol_1 > m_viol_2: + old_u = u[i_1] + u[i_1] = a[i_1]/(K[i_1,:]@v) + G[i_1,:] = u[i_1]*K[i_1,:]*v + + viol[i_1] = u[i_1]*K[i_1,:]@v - a[i_1] + viol_2 = viol_2 + ( K[i_1,:].T*(u[i_1] - old_u)*v) + + else: + old_v = v[i_2] + v[i_2] = b[i_2]/(K[:,i_2].T@u) + G[:,i_2] = u*K[:,i_2]*v[i_2] + #aviol = (G@one_m - a) + #aviol_2 = (G.T@one_n - b) + viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u + viol_2[i_2] = v[i_2]*K[:,i_2]@u - b[i_2] + + #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) + + if log: + log['u'] = u + log['v'] = v + + if log: + return G,log + else: + return G + def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ -- cgit v1.2.3 From eb17e022fee209f3d363a6f8dcbb0064fccde1ad Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:07:00 +0200 Subject: correct if error bug --- ot/bregman.py | 58 +++++++++++++++++++++++++++++----------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 1f9874e..8538c92 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -103,10 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - if method.lower() == 'greenkhorn': + elif method.lower() == 'greenkhorn': def sink(): return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log) elif method.lower() == 'sinkhorn_stabilized': def sink(): return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, @@ -417,17 +417,16 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) - -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log = False): +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): """ Solve the entropic regularization optimal transport problem and return the OT matrix - + The algorithm used is based on the paper - + Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration by Jason Altschuler, Jonathan Weed, Philippe Rigollet appeared at NIPS 2017 - + which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. The function solves the following optimization problem: @@ -499,21 +498,21 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log ot.optim.cg : General regularized OT """ - + i = 0 - + n = a.shape[0] m = b.shape[0] - + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - - u = np.ones(n)/n - v = np.ones(m)/m + + u = np.ones(n) / n + v = np.ones(m) / m G = np.diag(u)@K@np.diag(v) - + one_n = np.ones(n) one_m = np.ones(m) viol = G@one_m - a @@ -524,41 +523,42 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log log['v'] = v while i < numItermax and stopThr_val > stopThr: - i +=1 + i += 1 i_1 = np.argmax(np.abs(viol)) i_2 = np.argmax(np.abs(viol_2)) m_viol_1 = np.abs(viol[i_1]) m_viol_2 = np.abs(viol_2[i_2]) - stopThr_val = np.maximum(m_viol_1,m_viol_2) - + stopThr_val = np.maximum(m_viol_1, m_viol_2) + if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1]/(K[i_1,:]@v) - G[i_1,:] = u[i_1]*K[i_1,:]*v + u[i_1] = a[i_1] / (K[i_1, :]@v) + G[i_1, :] = u[i_1] * K[i_1, :] * v - viol[i_1] = u[i_1]*K[i_1,:]@v - a[i_1] - viol_2 = viol_2 + ( K[i_1,:].T*(u[i_1] - old_u)*v) + viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1] + viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v) else: old_v = v[i_2] - v[i_2] = b[i_2]/(K[:,i_2].T@u) - G[:,i_2] = u*K[:,i_2]*v[i_2] + v[i_2] = b[i_2] / (K[:, i_2].T@u) + G[:, i_2] = u * K[:, i_2] * v[i_2] #aviol = (G@one_m - a) #aviol_2 = (G.T@one_n - b) - viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u - viol_2[i_2] = v[i_2]*K[:,i_2]@u - b[i_2] - + viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u + viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2] + #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) - + if log: log['u'] = u log['v'] = v - + if log: - return G,log + return G, log else: return G + def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ -- cgit v1.2.3 From 7ffd4fef3260e086b0b1ed050f5cb4b83195b122 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:14:44 +0200 Subject: remove @ for python compatibility+ comments alexandre --- ot/bregman.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 8538c92..faa6365 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -480,7 +480,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.sinkhorn(a,b,M,1) + >>> ot.bregman.greenkhorn(a,b,M,1) array([[ 0.36552929, 0.13447071], [ 0.13447071, 0.36552929]]) @@ -505,18 +505,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= m = b.shape[0] # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) + K = np.empty_like(M) np.divide(M, -reg, out=K) np.exp(K, out=K) - u = np.ones(n) / n - v = np.ones(m) / m - G = np.diag(u)@K@np.diag(v) + u = np.full(n, 1. / n) + v = np.full(m, 1. / m) + G = u[:, np.newaxis] * K * v[np.newaxis, :] one_n = np.ones(n) one_m = np.ones(m) - viol = G@one_m - a - viol_2 = G.T@one_n - b + viol = G.sum(1) - a + viol_2 = G.sum(0) - b stopThr_val = 1 if log: log['u'] = u @@ -532,26 +532,26 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1] / (K[i_1, :]@v) + u[i_1] = a[i_1] / (K[i_1, :].dot(v)) G[i_1, :] = u[i_1] * K[i_1, :] * v - viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1] + viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v) else: old_v = v[i_2] - v[i_2] = b[i_2] / (K[:, i_2].T@u) + v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) G[:, i_2] = u * K[:, i_2] * v[i_2] #aviol = (G@one_m - a) #aviol_2 = (G.T@one_n - b) viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u - viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2] + viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) - if log: - log['u'] = u - log['v'] = v + if log: + log['u'] = u + log['v'] = v if log: return G, log -- cgit v1.2.3 From 24a53ef2dba0a43c282f6b31937c3e7901df7930 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:17:21 +0200 Subject: add contributor --- README.md | 1 + ot/bregman.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'ot') diff --git a/README.md b/README.md index 6a6686c..4d824ce 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,7 @@ The contributors to this library are: * [Antoine Rolet](https://arolet.github.io/) * Erwan Vautier (Gromov-Wasserstein) * [Kilian Fatras](https://kilianfatras.github.io/) +* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/ot/bregman.py b/ot/bregman.py index faa6365..97027e8 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -536,7 +536,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= G[i_1, :] = u[i_1] * K[i_1, :] * v viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] - viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v) + viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v) else: old_v = v[i_2] @@ -544,7 +544,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= G[:, i_2] = u * K[:, i_2] * v[i_2] #aviol = (G@one_m - a) #aviol_2 = (G.T@one_n - b) - viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u + viol += (-old_v + v[i_2]) * K[:, i_2] * u viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) -- cgit v1.2.3 From 55e8392993919d3c67538756663abd943d3bb491 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:19:18 +0200 Subject: remove unused variable --- ot/bregman.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 97027e8..6e446a1 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -513,8 +513,6 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= v = np.full(m, 1. / m) G = u[:, np.newaxis] * K * v[np.newaxis, :] - one_n = np.ones(n) - one_m = np.ones(m) viol = G.sum(1) - a viol_2 = G.sum(0) - b stopThr_val = 1 -- cgit v1.2.3 From 1d494107611c2e6e2249b7a624e64cec6357b4bd Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:23:02 +0200 Subject: implement for loop --- ot/bregman.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 6e446a1..05f7c75 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -520,7 +520,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= log['u'] = u log['v'] = v - while i < numItermax and stopThr_val > stopThr: + for i in range(numItermax): i += 1 i_1 = np.argmax(np.abs(viol)) i_2 = np.argmax(np.abs(viol_2)) @@ -547,6 +547,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) + if stopThr_val <= stopThr: + break + else: + print('Warning: Algorithm did not converge') + if log: log['u'] = u log['v'] = v -- cgit v1.2.3 From 75fe96c183852971bb7be1da39af202b9f7d6e6c Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:25:25 +0200 Subject: remove i+1 --- ot/bregman.py | 1 - 1 file changed, 1 deletion(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 05f7c75..1f5150a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -521,7 +521,6 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= log['v'] = v for i in range(numItermax): - i += 1 i_1 = np.argmax(np.abs(viol)) i_2 = np.argmax(np.abs(viol_2)) m_viol_1 = np.abs(viol[i_1]) -- cgit v1.2.3 From dee6d6e16f6e5d328bc590089cf99ef586d7ca0f Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:34:32 +0200 Subject: correct reference number in doc --- README.md | 2 +- ot/bregman.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'ot') diff --git a/README.md b/README.md index 1c8114a..16fa153 100644 --- a/README.md +++ b/README.md @@ -232,4 +232,4 @@ You can also post bug reports and feature requests in Github issues. Make sure t [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. -[21] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 +[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 diff --git a/ot/bregman.py b/ot/bregman.py index 418de57..fd04fa4 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -489,7 +489,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= ---------- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - [21] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 See Also -- cgit v1.2.3 From 1b24b1fd60a7126cd1646525ac5d7cf25f382a3a Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 24 Sep 2018 10:40:15 +0200 Subject: remove variable i initialization --- ot/bregman.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index fd04fa4..d1057ff 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -499,8 +499,6 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= """ - i = 0 - n = a.shape[0] m = b.shape[0] -- cgit v1.2.3