summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.md36
-rw-r--r--README.md7
-rw-r--r--docs/source/readme.rst38
-rw-r--r--ot/bregman.py151
-rw-r--r--test/test_bregman.py3
5 files changed, 229 insertions, 6 deletions
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000..7f8acda
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,36 @@
+---
+name: Bug report
+about: Create a report to help us improve POT
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior:
+1 ...
+2.
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**Desktop (please complete the following information):**
+ - OS: [e.g. MacOSX, Windows, Ubuntu]
+ - Python version [2.7,3.6]
+- How was POT installed [source, pip, conda]
+
+Output of the following code snippet:
+```python
+import platform; print(platform.platform())
+import sys; print("Python", sys.version)
+import numpy; print("NumPy", numpy.__version__)
+import scipy; print("SciPy", scipy.__version__)
+import ot; print("POT", ot.__version__)
+```
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/README.md b/README.md
index e49c8da..0540723 100644
--- a/README.md
+++ b/README.md
@@ -14,10 +14,10 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:
* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
-* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cupy).
+* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy SInkhorn [22] with optional GPU implementation (requires cupy).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
-* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
+* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
@@ -161,6 +161,7 @@ The contributors to this library are:
* [Antoine Rolet](https://arolet.github.io/)
* Erwan Vautier (Gromov-Wasserstein)
* [Kilian Fatras](https://kilianfatras.github.io/)
+* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -226,3 +227,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
+
+[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 5d37f64..a839231 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -13,13 +13,14 @@ It provides the following solvers:
- OT Network Flow solver for the linear program/ Earth Movers Distance
[1].
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]
- and stabilized version [9][10] with optional GPU implementation
- (requires cudamat).
+ and stabilized version [9][10] and greedy SInkhorn [22] with optional
+ GPU implementation (requires cudamat).
- Smooth optimal transport solvers (dual and semi-dual) for KL and
squared L2 regularizations [17].
- Non regularized Wasserstein barycenters [16] with LP solver (only
small scale).
-- Bregman projections for Wasserstein barycenter [3] and unmixing [4].
+- Bregman projections for Wasserstein barycenter [3], convolutional
+ barycenter [21] and unmixing [4].
- Optimal transport for domain adaptation with group lasso
regularization [5]
- Conditional gradient [6] and Generalized conditional gradient for
@@ -29,6 +30,9 @@ It provides the following solvers:
pymanopt).
- Gromov-Wasserstein distances and barycenters ([13] and regularized
[12])
+- Stochastic Optimization for Large-scale Optimal Transport (semi-dual
+ problem [18] and dual problem [19])
+- Non regularized free support Wasserstein barycenters [20].
Some demonstrations (both in Python and Jupyter Notebook format) are
available in the examples folder.
@@ -219,6 +223,9 @@ The contributors to this library are:
- `Stanislas Chambon <https://slasnista.github.io/>`__
- `Antoine Rolet <https://arolet.github.io/>`__
- Erwan Vautier (Gromov-Wasserstein)
+- `Kilian Fatras <https://kilianfatras.github.io/>`__
+- `Alain
+ Rakotomamonjy <https://sites.google.com/site/alainrakotomamonjy/home>`__
This toolbox benefit a lot from open source research and we would like
to thank the following persons for providing some code (in various
@@ -334,6 +341,31 @@ Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
the Twenty-First International Conference on Artificial Intelligence and
Statistics (AISTATS).
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
+Optimization for Large-scale Optimal
+Transport <https://arxiv.org/abs/1605.08527>`__. Advances in Neural
+Information Processing Systems (2016).
+
+[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
+A.& Blondel, M. `Large-scale Optimal Transport and Mapping
+Estimation <https://arxiv.org/pdf/1711.02283.pdf>`__. International
+Conference on Learning Representation (2018)
+
+[20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
+Barycenters <http://proceedings.mlr.press/v32/cuturi14.html>`__.
+International Conference in Machine Learning
+
+[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
+Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
+Efficient optimal transportation on geometric
+domains <https://dl.acm.org/citation.cfm?id=2766963>`__. ACM
+Transactions on Graphics (TOG), 34(4), 66.
+
+[22] J. Altschuler, J.Weed, P. Rigollet, (2017) `Near-linear time
+approximation algorithms for optimal transport via Sinkhorn
+iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
+Advances in Neural Information Processing Systems (NIPS) 31
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
diff --git a/ot/bregman.py b/ot/bregman.py
index 35e51f8..d1057ff 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
'sinkhorn_epsilon_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
@@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ elif method.lower() == 'greenkhorn':
+ def sink():
+ return greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
def sink():
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
@@ -197,6 +201,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
See Also
@@ -204,6 +210,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
+ ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
@@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+ """
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The algorithm used is based on the paper
+
+ Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
+ by Jason Altschuler, Jonathan Weed, Philippe Rigollet
+ appeared at NIPS 2017
+
+ which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.bregman.greenkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ n = a.shape[0]
+ m = b.shape[0]
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty_like(M)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ u = np.full(n, 1. / n)
+ v = np.full(m, 1. / m)
+ G = u[:, np.newaxis] * K * v[np.newaxis, :]
+
+ viol = G.sum(1) - a
+ viol_2 = G.sum(0) - b
+ stopThr_val = 1
+ if log:
+ log['u'] = u
+ log['v'] = v
+
+ for i in range(numItermax):
+ i_1 = np.argmax(np.abs(viol))
+ i_2 = np.argmax(np.abs(viol_2))
+ m_viol_1 = np.abs(viol[i_1])
+ m_viol_2 = np.abs(viol_2[i_2])
+ stopThr_val = np.maximum(m_viol_1, m_viol_2)
+
+ if m_viol_1 > m_viol_2:
+ old_u = u[i_1]
+ u[i_1] = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = u[i_1] * K[i_1, :] * v
+
+ viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+
+ else:
+ old_v = v[i_2]
+ v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * v[i_2]
+ #aviol = (G@one_m - a)
+ #aviol_2 = (G.T@one_n - b)
+ viol += (-old_v + v[i_2]) * K[:, i_2] * u
+ viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
+
+ #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+
+ if stopThr_val <= stopThr:
+ break
+ else:
+ print('Warning: Algorithm did not converge')
+
+ if log:
+ log['u'] = u
+ log['v'] = v
+
+ if log:
+ return G, log
+ else:
+ return G
+
+
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
diff --git a/test/test_bregman.py b/test/test_bregman.py
index a141078..b0c3358 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -71,11 +71,14 @@ def test_sinkhorn_variants():
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
+ G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, Gerr)
+ np.testing.assert_allclose(G0, G_green, atol=1e-5)
+ print(G0, G_green)
def test_bary():