authorKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-03-29 12:41:43 +0100
committerKilian Fatras <kilianfatras@Kilians-MacBook-Air.local>2019-03-29 12:41:43 +0100
commita2545b5a503c95c9bf07948929b77e9c3f4f28d3 (patch)
tree84bc0c169c1121bdff56e77c2c6cc88a68efba67
parent2384380536e3cc405e4db9f4b31cb48d309f257c (diff)
add empirical sinkhorn and sinkhorn divergence functions
-rw-r--r--  README.md                        2
-rw-r--r--  examples/plot_OT_2D_samples.py  26
-rw-r--r--  ot/bregman.py                  269
-rw-r--r--  test/test_bregman.py            57
4 files changed, 354 insertions, 0 deletions
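In short, the commit adds three helpers to ot.bregman that build the cost matrix from raw samples before calling the existing Sinkhorn solvers. A minimal sketch of their intended use, assuming POT is importable as ot (toy data, illustrative values):

import numpy as np
import ot

n = 10
X_s = np.reshape(np.arange(n), (n, 1))            # source samples, shape (ns, d)
X_t = np.reshape(np.arange(0, 2 * n, 2), (n, 1))  # target samples, shape (nt, d)
reg = 1e-1                                        # entropic regularization term

G = ot.bregman.empirical_sinkhorn(X_s, X_t, reg)             # OT plan from samples
W = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)            # entropic OT loss
S = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)  # Sinkhorn divergence [23]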
diff --git a/README.md b/README.md
index b068131..dbd93fc 100644
--- a/README.md
+++ b/README.md
@@ -230,3 +230,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
+
+[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index bb952a0..63126ba 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -10,6 +10,7 @@ sum of diracs. The OT matrix is plotted with the samples.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -100,3 +101,28 @@ pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
pl.show()
+
+
+##############################################################################
+# Empirical Sinkhorn
+# ------------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn from samples')
+
+pl.show()
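For context, the call above should agree with the plain solver once the cost matrix is built explicitly. A quick sketch of that equivalence, assuming xs, xt and lambd from the example above (uniform weights and the default squared Euclidean cost):

import numpy as np
import ot

a, b = ot.unif(len(xs)), ot.unif(len(xt))  # uniform sample weights
M = ot.dist(xs, xt)                        # 'sqeuclidean' metric by default

G_ref = ot.sinkhorn(a, b, M, lambd)                   # plan from an explicit cost matrix
G_emp = ot.bregman.empirical_sinkhorn(xs, xt, lambd)  # plan directly from samples
np.testing.assert_allclose(G_ref, G_emp, atol=1e-05)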
diff --git a/ot/bregman.py b/ot/bregman.py
index 013bc33..f1b18f8 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -5,6 +5,7 @@ Bregman projections for regularized OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -1296,3 +1297,271 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1), log
else:
return np.sum(K0, axis=1)
+
+
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Solve the entropic regularization optimal transport problem and return the
+ OT matrix from empirical data
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Regularized optimal transportation matrix for the given parameters
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+    >>> import numpy as np
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
+    >>> print(emp_sinkhorn)
+    [[4.99977301e-01 2.26989344e-05]
+     [2.26989344e-05 4.99977301e-01]]
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+    if a is None:
+        a = ot.unif(np.shape(X_s)[0])
+    if b is None:
+        b = ot.unif(np.shape(X_t)[0])
+
+    M = ot.dist(X_s, X_t, metric=metric)
+
+    if log:
+        pi, log = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+        return pi, log
+    else:
+        pi = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+        return pi
+
+
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Solve the entropic regularization optimal transport problem from empirical
+ data and return the OT loss
+
+
+ The function solves the following optimization problem:
+
+ .. math::
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+    W : (1,) ndarray
+        Regularized optimal transportation loss for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+    >>> import numpy as np
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+    >>> print(loss_sinkhorn)
+    [4.53978687e-05]
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ '''
+
+    if a is None:
+        a = ot.unif(np.shape(X_s)[0])
+    if b is None:
+        b = ot.unif(np.shape(X_t)[0])
+
+    M = ot.dist(X_s, X_t, metric=metric)
+
+    if log:
+        sinkhorn_loss, log = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+        return sinkhorn_loss, log
+    else:
+        sinkhorn_loss = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+        return sinkhorn_loss
+
+
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+ '''
+ Compute the sinkhorn divergence loss from empirical data
+
+ The function solves the following optimization problem:
+
+ .. math::
+        S = 2\cdot\left(\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)\right) -
+        \left(\min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)\right) -
+        \left(\min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)\right)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+
+ \gamma_a 1 = a
+
+ \gamma_a^T 1= a
+
+ \gamma_a\geq 0
+
+ \gamma_b 1 = b
+
+ \gamma_b^T 1= b
+
+ \gamma_b\geq 0
+ where :
+
+    - :math:`M` (resp. :math:`M_a`, :math:`M_b`) is the (ns, nt) metric cost matrix (resp. (ns, ns) and (nt, nt))
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+
+ Parameters
+ ----------
+ X_s : np.ndarray (ns, d)
+ samples in the source domain
+ X_t : np.ndarray (nt, d)
+ samples in the target domain
+ reg : float
+ Regularization term >0
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+    W : (1,) ndarray
+        Sinkhorn divergence value for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+    >>> import numpy as np
+    >>> n_s = 2
+    >>> n_t = 4
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
+    >>> print(emp_sinkhorn_div)
+    [2.99977435]
+
+
+ References
+ ----------
+
+    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
+ '''
+
+    sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) -
+                    empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) -
+                    empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs))
+    return max(0, sinkhorn_div)
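As a sanity check on the formula above, a sketch mirroring the test below: the empirical helper should match the same quantity assembled from ot.sinkhorn2 with explicit cost matrices:

import numpy as np
import ot

n = 10
a, b = ot.unif(n), ot.unif(n)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))

# explicit cost matrices for the three terms of the divergence
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)

ref = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1)
       - ot.sinkhorn2(b, b, M_t, 1))
emp = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
np.testing.assert_allclose(emp, ref, atol=1e-05)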
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 90eaf27..b890df1 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -1,6 +1,7 @@
"""Tests for module bregman on OT with bregman projections """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -187,3 +188,59 @@ def test_unmix():
ot.bregman.unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)
+
+
+def test_empirical_sinkhorn():
+    # test sinkhorn
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+
+    X_s = np.reshape(np.arange(n), (n, 1))
+    X_t = np.reshape(np.arange(0, n), (n, 1))
+    M = ot.dist(X_s, X_t)
+    M_e = ot.dist(X_s, X_t, metric='euclidean')
+
+    G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+    G_e = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='euclidean')
+    sinkhorn_e = ot.sinkhorn(a, b, M_e, 1)
+
+    loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
+    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+    # check constraints
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_e.sum(1), G_e.sum(1), atol=1e-05)  # metric euclidean
+    np.testing.assert_allclose(
+        sinkhorn_e.sum(0), G_e.sum(0), atol=1e-05)  # metric euclidean
+    np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence():
+    # test sinkhorn divergence
+    n = 10
+    a = ot.unif(n)
+    b = ot.unif(n)
+    X_s = np.reshape(np.arange(n), (n, 1))
+    X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+    M = ot.dist(X_s, X_t)
+    M_s = ot.dist(X_s, X_s)
+    M_t = ot.dist(X_t, X_t)
+
+    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+    sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
+                    ot.sinkhorn2(b, b, M_t, 1))
+
+    # check convergence to the explicit formula
+    np.testing.assert_allclose(
+        emp_sinkhorn_div, sinkhorn_div, atol=1e-05)
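One further property worth noting (a sketch under the same assumptions, not part of the commit): the divergence of a distribution against itself vanishes, since the three terms cancel and the result is clipped at zero:

import numpy as np
import ot

n = 10
X = np.reshape(np.arange(n), (n, 1))

# 2*W(X, X) - W(X, X) - W(X, X) = 0, then max(0, .) keeps it at zero
div_self = ot.bregman.empirical_sinkhorn_divergence(X, X, 1)
np.testing.assert_allclose(div_self, 0, atol=1e-09)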