author     Rémi Flamary <remi.flamary@gmail.com>  2019-04-23 11:10:14 +0200
committer  GitHub <noreply@github.com>             2019-04-23 11:10:14 +0200
commit     c5108efc7b6702e1af3928bef1032e6b37734d1c (patch)
tree       5ed685351f4349c83e8d3f7d56e0f9b0550aba5b
parent     2384380536e3cc405e4db9f4b31cb48d309f257c (diff)
parent     17fa4f9a8cf7ffd1a58853b4091cee0238a1100b (diff)
Merge pull request #80 from kilianFatras/master
add empirical sinkhorn and sinkhorn divergence functions
-rw-r--r--  README.md                         5
-rw-r--r--  examples/plot_OT_2D_samples.py   26
-rw-r--r--  ot/bregman.py                   301
-rw-r--r--  ot/stochastic.py                  3
-rw-r--r--  test/test_bregman.py             67
5 files changed, 401 insertions, 1 deletions
diff --git a/README.md b/README.md
index b068131..dd34a97 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,8 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:
* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
-* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy SInkhorn [22] with optional GPU implementation (requires cupy).
+* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2], stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU implementation (requires cupy).
+* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
@@ -230,3 +231,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
+
+[23] Genevay, A., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
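The three entry points this PR adds all live in ot.bregman and work directly on raw samples. A minimal sketch of the new API, assuming POT with this patch applied (the Gaussian sample data is illustrative, not from the diff):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s = rng.randn(100, 2)        # source samples
    X_t = rng.randn(100, 2) + 1.0  # target samples, shifted
    reg = 0.1                      # entropic regularization strength

    # OT plan, entropic OT loss and debiased Sinkhorn divergence, from samples alone
    G = ot.bregman.empirical_sinkhorn(X_s, X_t, reg)
    W = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)
    S = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)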
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index bb952a0..63126ba 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -10,6 +10,7 @@ sum of diracs. The OT matrix is plotted with the samples.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -100,3 +101,28 @@ pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
pl.show()
+
+
+##############################################################################
+# Empirical Sinkhorn
+# -------------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn from samples')
+
+pl.show()
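The example stops at plotting the empirical Sinkhorn plan. A natural follow-up, not part of this diff, is to evaluate the two new loss functions on the same data (xs, xt and lambd as defined in the script above):

    # entropic OT loss vs. debiased Sinkhorn divergence on the same samples
    loss = ot.bregman.empirical_sinkhorn2(xs, xt, lambd)
    div = ot.bregman.empirical_sinkhorn_divergence(xs, xt, lambd)
    print('entropic OT loss:', loss)
    print('Sinkhorn divergence:', div)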
diff --git a/ot/bregman.py b/ot/bregman.py
index 013bc33..dc43834 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -5,10 +5,12 @@ Bregman projections for regularized OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
import numpy as np
+from .utils import unif, dist
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1296,3 +1298,302 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1), log
else:
return np.sum(K0, axis=1)
+
+
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Solve the entropic regularization optimal transport problem and return the
+ OT matrix from empirical data
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ metric : str, optional
+ Metric used to compute the cost matrix (passed to ot.dist)
+ numIterMax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Regularized optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import numpy as np
+ >>> n_s = 2
+ >>> n_t = 2
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+ >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
+ >>> print(emp_sinkhorn)
+ [[4.99977301e-01 2.26989344e-05]
+ [2.26989344e-05 4.99977301e-01]]
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+ if a is None:
+ a = unif(np.shape(X_s)[0])
+ if b is None:
+ b = unif(np.shape(X_t)[0])
+
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
+
+
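As the body above shows, empirical_sinkhorn is a thin convenience wrapper: it defaults to uniform weights, builds the cost matrix with dist, then delegates to sinkhorn. A sketch of the equivalent explicit pipeline (sample data is illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    X_s, X_t = rng.randn(50, 2), rng.randn(60, 2)
    reg = 0.5

    # convenience wrapper: weights and cost matrix are built internally
    G1 = ot.bregman.empirical_sinkhorn(X_s, X_t, reg)

    # equivalent explicit calls
    a, b = ot.unif(50), ot.unif(60)
    M = ot.dist(X_s, X_t, metric='sqeuclidean')
    G2 = ot.sinkhorn(a, b, M, reg)

    np.testing.assert_allclose(G1, G2)  # identical transport plans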
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Solve the entropic regularization optimal transport problem from empirical
+ data and return the OT loss
+
+
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ metric : str, optional
+ Metric used to compute the cost matrix (passed to ot.dist)
+ numIterMax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : (1,) ndarray or float
+ Regularized optimal transport loss for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import numpy as np
+ >>> n_s = 2
+ >>> n_t = 2
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+ >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+ >>> print(loss_sinkhorn)
+ [4.53978687e-05]
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+ if a is None:
+ a = unif(np.shape(X_s)[0])
+ if b is None:
+ b = unif(np.shape(X_t)[0])
+
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_loss
+
+
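empirical_sinkhorn2 follows the same pattern but returns the transport loss rather than the plan. A sketch of the log=True variant (data is illustrative):

    import numpy as np
    import ot

    X_s = np.arange(4, dtype=float).reshape(4, 1)
    X_t = X_s + 0.5

    # loss together with the underlying solver log
    loss, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=0.1, log=True)
    print(loss)        # entropic OT loss between the two samples
    print(log.keys())  # convergence information from sinkhorn2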
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Compute the Sinkhorn divergence from empirical data
+
+ The function solves the following optimization problems and returns the
+ Sinkhorn divergence :math:`S`:
+
+ .. math::
+
+ W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+
+ W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+
+ S &= W - \frac{1}{2} (W_a + W_b)
+
+ .. math::
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+
+ \gamma_a 1 = a
+
+ \gamma_a^T 1= a
+
+ \gamma_a\geq 0
+
+ \gamma_b 1 = b
+
+ \gamma_b^T 1= b
+
+ \gamma_b\geq 0
+ where :
+
+ - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`a` and :math:`b` are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ metric : str, optional
+ Metric used to compute the cost matrices (passed to ot.dist)
+ numIterMax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ S : float
+ Sinkhorn divergence for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import numpy as np
+ >>> n_s = 2
+ >>> n_t = 4
+ >>> reg = 0.1
+ >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+ >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
+ >>> print(emp_sinkhorn_div)
+ [2.99977435]
+
+
+ References
+ ----------
+
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
+ '''
+ if log:
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ # the self-comparison terms use the same weights on both marginals
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+
+ log = {}
+ log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
+ log['sinkhorn_loss_a'] = sinkhorn_loss_a
+ log['sinkhorn_loss_b'] = sinkhorn_loss_b
+ log['log_sinkhorn_ab'] = log_ab
+ log['log_sinkhorn_a'] = log_a
+ log['log_sinkhorn_b'] = log_b
+
+ return max(0, sinkhorn_div), log
+
+ else:
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ # the self-comparison terms use the same weights on both marginals
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ return max(0, sinkhorn_div)
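The function is simply the composition of three empirical_sinkhorn2 calls following S = W(a, b) - (W(a, a) + W(b, b)) / 2, clipped at zero. A hand-rolled sketch of the same computation (data is illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s, X_t = rng.randn(30, 2), rng.randn(30, 2) + 2.0
    reg = 1.0

    W_ab = float(ot.bregman.empirical_sinkhorn2(X_s, X_t, reg))
    W_aa = float(ot.bregman.empirical_sinkhorn2(X_s, X_s, reg))
    W_bb = float(ot.bregman.empirical_sinkhorn2(X_t, X_t, reg))

    # debiasing removes the entropic self-transport cost of each measure
    S = max(0.0, W_ab - 0.5 * (W_aa + W_bb))
    print(S)  # should match ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)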
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 0db39c8..85c4230 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -348,8 +348,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
s.t. \gamma 1 = a
+
\gamma^T 1= b
+
\gamma \geq 0
Where :
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 90eaf27..7f4972c 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -1,6 +1,7 @@
"""Tests for module bregman on OT with bregman projections """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -187,3 +188,69 @@ def test_unmix():
ot.bregman.unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)
+
+
+def test_empirical_sinkhorn():
+ # test sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+ G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+ G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
+ sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+ loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidean
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidean
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05)  # metric minkowski
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05)  # metric minkowski
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence():
+ # Test Sinkhorn divergence
+ n = 10
+ a = ot.unif(n)
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_s = ot.dist(X_s, X_s)
+ M_t = ot.dist(X_t, X_t)
+
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+ sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
+
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+ sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
+ sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
+ sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
+ sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
+
+ # check constraints
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, sinkhorn_div, atol=1e-05)  # empirical divergence matches explicit computation
+ np.testing.assert_allclose(
+ emp_sinkhorn_div_log, sink_div_log, atol=1e-05)  # same check with log=True