From 42a501c5d839c010bbfa3a4440b43cb4f9775fc7 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 11 Mar 2019 10:39:03 +0100 Subject: add test sinkhorn+log --- test/test_bregman.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 14edaf5..90eaf27 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -81,6 +81,31 @@ def test_sinkhorn_variants(): print(G0, G_green) +def test_sinkhorn_variants_log(): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True) + Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Ges, loges = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True) + G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, Gerr) + np.testing.assert_allclose(G0, G_green, atol=1e-5) + print(G0, G_green) + + def test_bary(): n_bins = 100 # nb bins -- cgit v1.2.3 From a2545b5a503c95c9bf07948929b77e9c3f4f28d3 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Fri, 29 Mar 2019 12:41:43 +0100 Subject: add empirical sinkhorn and sikhorn divergence functions --- README.md | 2 + examples/plot_OT_2D_samples.py | 26 ++++ ot/bregman.py | 269 +++++++++++++++++++++++++++++++++++++++++ test/test_bregman.py | 57 +++++++++ 4 files changed, 354 insertions(+) (limited to 'test') diff --git a/README.md b/README.md index b068131..dbd93fc 100644 --- a/README.md +++ b/README.md @@ -230,3 +230,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. [22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 + +[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index bb952a0..63126ba 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -10,6 +10,7 @@ sum of diracs. The OT matrix is plotted with the samples. 
""" # Author: Remi Flamary +# Kilian Fatras # # License: MIT License @@ -100,3 +101,28 @@ pl.legend(loc=0) pl.title('OT matrix Sinkhorn with samples') pl.show() + + +############################################################################## +# Emprirical Sinkhorn +# ---------------- + +#%% sinkhorn + +# reg term +lambd = 1e-3 + +Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) + +pl.figure(7) +pl.imshow(Ges, interpolation='nearest') +pl.title('OT matrix empirical sinkhorn') + +pl.figure(8) +ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix Sinkhorn from samples') + +pl.show() diff --git a/ot/bregman.py b/ot/bregman.py index 013bc33..f1b18f8 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,6 +5,7 @@ Bregman projections for regularized OT # Author: Remi Flamary # Nicolas Courty +# Kilian Fatras # # License: MIT License @@ -1296,3 +1297,271 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, return np.sum(K0, axis=1), log else: return np.sum(K0, axis=1) + + +def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): + ''' + Solve the entropic regularization optimal transport problem and return the + OT matrix from empirical data + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : np.ndarray (ns, d) + samples in the source domain + X_t : np.ndarray (nt, d) + samples in the target domain + reg : float + Regularization term >0 + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_s = 2 + >>> n_t = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) + >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False) + >>> print(emp_sinkhorn) + >>> [[4.99977301e-01 2.26989344e-05] + [2.26989344e-05 4.99977301e-01]] + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
+    '''
+
+    if a is None:
+        a = ot.unif(np.shape(X_s)[0])
+    if b is None:
+        b = ot.unif(np.shape(X_t)[0])
+    M = ot.dist(X_s, X_t, metric=metric)
+    if log == False:
+        pi = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+        return pi
+
+    if log == True:
+        pi, log = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+        return pi, log
+
+
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Solve the entropic regularization optimal transport problem from empirical
+    data and return the OT loss
+
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1= b
+
+             \gamma\geq 0
+    where :
+
+    - M is the (ns,nt) metric cost matrix
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - a and b are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term >0
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    W : (1,) ndarray
+        Regularized optimal transport loss for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 2
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+    >>> print(loss_sinkhorn)
+    [4.53978687e-05]
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    '''
+
+    if a is None:
+        a = ot.unif(np.shape(X_s)[0])
+    if b is None:
+        b = ot.unif(np.shape(X_t)[0])
+
+    M = ot.dist(X_s, X_t, metric=metric)
+    if log == False:
+        sinkhorn_loss = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return sinkhorn_loss
+
+    if log == True:
+        sinkhorn_loss, log = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return sinkhorn_loss, log
+
+
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+    '''
+    Compute the sinkhorn divergence loss from empirical data
+
+    The function solves the following optimization problem:
+
+    .. math::
+        S = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) -
+            \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) -
+            \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1= b
+
+             \gamma\geq 0
+
+             \gamma_a 1 = a
+
+             \gamma_a^T 1= a
+
+             \gamma_a\geq 0
+
+             \gamma_b 1 = b
+
+             \gamma_b^T 1= b
+
+             \gamma_b\geq 0
+    where :
+
+    - M (resp. :math:`M_a`, :math:`M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - a and b are source and target weights (sum to 1)
+
+
+    Parameters
+    ----------
+    X_s : np.ndarray (ns, d)
+        samples in the source domain
+    X_t : np.ndarray (nt, d)
+        samples in the target domain
+    reg : float
+        Regularization term >0
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,)
+        sample weights in the target domain
+    numIterMax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    S : float
+        Sinkhorn divergence for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+    Examples
+    --------
+
+    >>> n_s = 2
+    >>> n_t = 4
+    >>> reg = 0.1
+    >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
+    >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+    >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
+    >>> print(emp_sinkhorn_div)
+    [2.99977435]
+
+
+    References
+    ----------
+
+    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+    '''
+
+    sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
+                    empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
+                    empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs))
+    return max(0, sinkhorn_div)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 90eaf27..b890df1 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -1,6 +1,7 @@
 """Tests for module bregman on OT with bregman projections
 
 """
 
 # Author: Remi Flamary
+#         Kilian Fatras
 #
 # License: MIT License
 
@@ -187,3 +188,59 @@ def test_unmix():
     ot.bregman.unmix(a, D, M, M0, h0, reg,
                      1, alpha=0.01, log=True, verbose=True)
+
+
+def test_empirical_sinkhorn():
+    # test sinkhorn
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+    M = ot.dist(X_s, X_t)
+    M_e = ot.dist(X_s, X_t, metric='euclidean')
+
+    rng = np.random.RandomState(0)
+
+    X_s = np.reshape(np.arange(n), (n, 1))
+    X_t = np.reshape(np.arange(0, n), (n, 1))
+
+    G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+    G_e = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+    sinkhorn_e = ot.sinkhorn(a, b, M_e, 1)
+
+    loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
+    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+    # check constratints
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidian
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(0), G_sqe.sum(0),
atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_e.sum(1), G_e.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + sinkhorn_e.sum(0), G_e.sum(0), atol=1e-05) # metric euclidian + np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + + +def test_empirical_sinkhorn_divergence(): + #Test sinkhorn divergence + n = 10 + a = ot.unif(n) + b = ot.unif(n) + X_s = np.reshape(np.arange(n), (n, 1)) + X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) + M = ot.dist(X_s, X_t) + M_s = ot.dist(X_s, X_s) + M_t = ot.dist(X_t, X_t) + + emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, 1) + sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - + ot.sinkhorn2(b, b, M_t, 1)) + + # check constratints + np.testing.assert_allclose( + emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn + np.testing.assert_allclose( + emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn -- cgit v1.2.3 From 9569f893defa8e712a4f3199770a0df745d4cfff Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Fri, 29 Mar 2019 13:06:01 +0100 Subject: fix pep8 --- ot/bregman.py | 29 +++++++++++++++-------------- test/test_bregman.py | 6 ++---- 2 files changed, 17 insertions(+), 18 deletions(-) (limited to 'test') diff --git a/ot/bregman.py b/ot/bregman.py index f1b18f8..f6aa339 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1375,17 +1375,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI ''' if a is None: - a = ot.unif(np.shape(X_s)[0]) + a = utils.unif(np.shape(X_s)[0]) if b is None: - b = ot.unif(np.shape(X_t)[0]) + b = utils.unif(np.shape(X_t)[0]) + M = ot.dist(X_s, X_t, metric=metric) - if log == False: - pi = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi - if log == True: - pi, log = ot.sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): @@ -1464,18 +1465,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num ''' if a is None: - a = ot.unif(np.shape(X_s)[0]) + a = utils.unif(np.shape(X_s)[0]) if b is None: - b = ot.unif(np.shape(X_t)[0]) + b = utils.unif(np.shape(X_t)[0]) M = ot.dist(X_s, X_t, metric=metric) - if log == False: - sinkhorn_loss = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - return sinkhorn_loss - if log == True: - sinkhorn_loss, log = ot.sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_loss def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): diff --git a/test/test_bregman.py b/test/test_bregman.py index 
b890df1..8b001a7 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -195,13 +195,11 @@ def test_empirical_sinkhorn(): n = 100 a = ot.unif(n) b = ot.unif(n) - M = ot.dist(X_s, X_t) - M_e = ot.dist(X_s, X_t, metric='euclidean') - - rng = np.random.RandomState(0) X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) + M = ot.dist(X_s, X_t) + M_e = ot.dist(X_s, X_t, metric='euclidean') G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) -- cgit v1.2.3 From d754a645f9b4ef88d7e0aba1188fa83d7d58af1f Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Fri, 29 Mar 2019 13:24:54 +0100 Subject: typos PEP8 --- test/test_bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 8b001a7..4aae6cb 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -233,7 +233,7 @@ def test_empirical_sinkhorn_divergence(): M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, 1) + emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - ot.sinkhorn2(b, b, M_t, 1)) -- cgit v1.2.3 From 1ceb1a9cc96aad54e525c2021851b8639e2f3449 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Sun, 31 Mar 2019 12:14:54 +0200 Subject: fix metric test --- test/test_bregman.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 4aae6cb..0ebd546 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -199,13 +199,13 @@ def test_empirical_sinkhorn(): X_s = np.reshape(np.arange(n), (n, 1)) X_t = np.reshape(np.arange(0, n), (n, 1)) M = ot.dist(X_s, X_t) - M_e = ot.dist(X_s, X_t, metric='euclidean') + M_m = ot.dist(X_s, X_t, metric='minkowski') G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) - G_e = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) - sinkhorn_e = ot.sinkhorn(a, b, M_e, 1) + G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski') + sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1) loss_sinkhorn = ot.sinkhorn2(a, b, M, 1) @@ -216,9 +216,9 @@ def test_empirical_sinkhorn(): np.testing.assert_allclose( sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( - sinkhorn_e.sum(1), G_e.sum(1), atol=1e-05) # metric euclidian + sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian np.testing.assert_allclose( - sinkhorn_e.sum(0), G_e.sum(0), atol=1e-05) # metric euclidian + sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) -- cgit v1.2.3 From 780bdfee3c622698dc9b18a02fa06381314aa56d Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Thu, 4 Apr 2019 13:45:33 +0200 Subject: fix log in sinkhorn div and add log tests --- ot/bregman.py | 26 ++++++++++++++++++++++---- test/test_bregman.py | 12 +++++++++++- 2 files changed, 33 insertions(+), 5 deletions(-) (limited to 'test') diff --git a/ot/bregman.py b/ot/bregman.py index 47554fb..7acfcf1 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1569,8 +1569,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. 
[23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 ''' + if log: + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)) - return max(0, sinkhorn_div) + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + + log = {} + log['sinkhorn_loss_ab'] = sinkhorn_loss_ab + log['sinkhorn_loss_a'] = sinkhorn_loss_a + log['sinkhorn_loss_b'] = sinkhorn_loss_b + log['log_sinkhorn_ab'] = log_ab + log['log_sinkhorn_a'] = log_a + log['log_sinkhorn_b'] = log_b + + return max(0, sinkhorn_div), log + else: + sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - + 1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - + 1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)) + return max(0, sinkhorn_div) diff --git a/test/test_bregman.py b/test/test_bregman.py index 0ebd546..68d3595 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -204,6 +204,9 @@ def test_empirical_sinkhorn(): G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1) sinkhorn_sqe = ot.sinkhorn(a, b, M, 1) + G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski') sinkhorn_m = ot.sinkhorn(a, b, M_m, 1) @@ -215,6 +218,10 @@ def test_empirical_sinkhorn(): sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian np.testing.assert_allclose( sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log + np.testing.assert_allclose( + sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log np.testing.assert_allclose( sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric euclidian np.testing.assert_allclose( @@ -237,8 +244,11 @@ def test_empirical_sinkhorn_divergence(): sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - ot.sinkhorn2(b, b, M_t, 1)) + emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True) + sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + # check constratints np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( - emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp 
sinkhorn + emp_sinkhorn_div_log, sinkhorn_div_log, atol=1e-05) # cf conv emp sinkhorn -- cgit v1.2.3 From 69186a6f4259d32fecac370f59efe16e2e460d04 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Thu, 4 Apr 2019 13:58:50 +0200 Subject: fix test sinkhorn div --- ot/bregman.py | 11 ++++++++--- test/test_bregman.py | 8 +++++--- 2 files changed, 13 insertions(+), 6 deletions(-) (limited to 'test') diff --git a/ot/bregman.py b/ot/bregman.py index 7acfcf1..dc43834 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1587,8 +1587,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli log['log_sinkhorn_b'] = log_b return max(0, sinkhorn_div), log + else: - sinkhorn_div = (empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - 1 / 2 * empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) - - 1 / 2 * empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)) + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) diff --git a/test/test_bregman.py b/test/test_bregman.py index 68d3595..58700e2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -241,11 +241,13 @@ def test_empirical_sinkhorn_divergence(): M_t = ot.dist(X_t, X_t) emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) - sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) - - ot.sinkhorn2(b, b, M_t, 1)) + sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True) - sinkhorn_div_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True) + sink_div_log, log_s = ot.sinkhorn2(a, b, M, 1) + sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1) + sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1) + sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b) # check constratints np.testing.assert_allclose( -- cgit v1.2.3 From 782d9b1ae9d8c0b01e32c2af925ac9b7efa42a70 Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Thu, 4 Apr 2019 14:11:36 +0200 Subject: fix test sinkhorn div --- test/test_bregman.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 58700e2..d5482f7 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -243,11 +243,11 @@ def test_empirical_sinkhorn_divergence(): emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1) sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1)) - emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 0.1, log=True) - sink_div_log, log_s = ot.sinkhorn2(a, b, M, 1) - sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1) - sink_div_log_b, 
log_s_b = ot.sinkhorn2(b, b, M_t, 1) - sink_div_log = sink_div_log - 1 / 2 * (sink_div_log_a + sink_div_log_b) + emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True) + sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True) + sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True) + sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True) + sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b) # check constratints np.testing.assert_allclose( -- cgit v1.2.3 From 17fa4f9a8cf7ffd1a58853b4091cee0238a1100b Mon Sep 17 00:00:00 2001 From: Kilian Fatras Date: Thu, 4 Apr 2019 14:16:52 +0200 Subject: fix test sinkhorn div --- test/test_bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index d5482f7..7f4972c 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -253,4 +253,4 @@ def test_empirical_sinkhorn_divergence(): np.testing.assert_allclose( emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn np.testing.assert_allclose( - emp_sinkhorn_div_log, sinkhorn_div_log, atol=1e-05) # cf conv emp sinkhorn + emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn -- cgit v1.2.3 From 6484c9ea301fc15ae53b4afe134941909f581ffe Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:11:48 +0200 Subject: Tests + contributions --- README.md | 1 + ot/gromov.py | 12 ++++++--- test/test_gromov.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_optim.py | 5 ++++ 4 files changed, 89 insertions(+), 4 deletions(-) (limited to 'test') diff --git a/README.md b/README.md index 9951773..9692344 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,7 @@ The contributors to this library are: * Erwan Vautier (Gromov-Wasserstein) * [Kilian Fatras](https://kilianfatras.github.io/) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) +* [Vayer Titouan](https://tvayer.github.io/) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/ot/gromov.py b/ot/gromov.py index ad68a1c..297b194 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -926,6 +926,10 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. 
""" + + class UndefinedParameter(Exception): + pass + S = len(Cs) d = Ys[0].shape[1] #dimension on the node features if p is None: @@ -938,7 +942,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_structure: if init_C is None: - C=Cs[0] + raise UndefinedParameter('If C is fixed it must be initialized') else: C=init_C else: @@ -950,7 +954,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_features: if init_X is None: - X=Ys[0] + raise UndefinedParameter('If X is fixed it must be initialized') else : X= init_X else: @@ -1004,13 +1008,13 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature # Cs is ns,ns # p is N,1 # ps is ns,1 - + T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns log_['Ts_iter'].append(T) - err_feature = np.linalg.norm(X - Xprev.reshape(d,N)) + err_feature = np.linalg.norm(X - Xprev.reshape(N,d)) err_structure = np.linalg.norm(C - Cprev) if log: diff --git a/test/test_gromov.py b/test/test_gromov.py index fb86274..07cd874 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + +def test_fgw(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0],2) + yt= ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M=ot.dist(ys,yt) + M/=M.max() + + G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence fgw + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence fgw + + +def test_fgw_barycenter(): + + ns = 50 + nt = 60 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) + + ys = np.random.randn(Xs.shape[0],2) + yt= np.random.randn(Xt.shape[0],2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + + n_samples = 3 + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, + fixed_structure=False,fixed_features=False, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + xalea = np.random.randn(n_samples, 2) + init_C = ot.dist(xalea, xalea) + + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5, + fixed_structure=True,init_C=init_C,fixed_features=False, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + init_X=np.random.randn(n_samples,ys.shape[1]) + + X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, + fixed_structure=False,fixed_features=True, init_X=init_X, + p=ot.unif(n_samples),loss_fun='square_loss', + max_iter=100, tol=1e-3) + np.testing.assert_allclose(C.shape, 
(n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_optim.py b/test/test_optim.py index dfefe59..1188ef6 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -65,3 +65,8 @@ def test_generalized_conditional_gradient(): np.testing.assert_allclose(a, G.sum(1), atol=1e-05) np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + +def test_solve_1d_linesearch_quad_funct(): + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1) -- cgit v1.2.3 From fa989062c17f87bd96aa58ad764fd3791ea11e22 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 15:00:50 +0200 Subject: Reame +pep8 --- README.md | 14 ++++ examples/plot_barycenter_fgw.py | 150 ++++++++++++++++++++-------------------- examples/plot_fgw.py | 138 ++++++++++++++++++------------------ test/test_gromov.py | 53 +++++++------- test/test_optim.py | 9 +-- 5 files changed, 190 insertions(+), 174 deletions(-) (limited to 'test') diff --git a/README.md b/README.md index fd27f9d..b6b215c 100644 --- a/README.md +++ b/README.md @@ -222,3 +222,17 @@ You can also post bug reports and feature requests in Github issues. Make sure t [16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) + +[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning + +[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. + +[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 + +[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + +[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). 
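As a quick orientation for this series of patches, here is a minimal usage sketch of the fused Gromov-Wasserstein solver exercised by test_fgw above; the toy data, sizes and alpha value are illustrative assumptions, not part of the patches:

import numpy as np
import ot

n = 30  # illustrative number of samples (assumption)
xs = np.random.randn(n, 2)   # source positions (structure)
xt = xs[::-1].copy()         # target positions
ys = np.random.randn(n, 2)   # source node features
yt = ys[::-1].copy()         # target node features

p = ot.unif(n)               # uniform source weights
q = ot.unif(n)               # uniform target weights

C1 = ot.dist(xs, xs)         # intra-domain structure matrices
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
M = ot.dist(ys, yt)          # across-domain feature cost
M /= M.max()

# alpha balances the feature term (alpha -> 0) against the structure term (alpha -> 1)
G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)

# the returned coupling satisfies the marginal constraints up to the solver tolerance:
# G.sum(1) is close to p and G.sum(0) is close to q
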
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py index f416629..9eea036 100644 --- a/examples/plot_barycenter_fgw.py +++ b/examples/plot_barycenter_fgw.py @@ -30,10 +30,11 @@ from matplotlib import cm from ot.gromov import fgw_barycenters #%% Graph functions -def find_thresh(C,inf=0.5,sup=3,step=10): + +def find_thresh(C, inf=0.5, sup=3, step=10): """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected - Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. - The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix and the original matrix. Parameters ---------- @@ -43,21 +44,22 @@ def find_thresh(C,inf=0.5,sup=3,step=10): The beginning of the linesearch sup : float The end of the linesearch - step : integer - Number of thresholds tested + step : integer + Number of thresholds tested """ - dist=[] - search=np.linspace(inf,sup,step) + dist = [] + search = np.linspace(inf, sup, step) for thresh in search: - Cprime=sp_to_adjency(C,0,thresh) - SC=shortest_path(Cprime,method='D') - SC[SC==float('inf')]=100 - dist.append(np.linalg.norm(SC-C)) - return search[np.argmin(dist)],dist - -def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): - """ Thresholds the structure matrix in order to compute an adjency matrix. - All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Cprime = sp_to_adjency(C, 0, thresh) + SC = shortest_path(Cprime, method='D') + SC[SC == float('inf')] = 100 + dist.append(np.linalg.norm(SC - C)) + return search[np.argmin(dist)], dist + + +def sp_to_adjency(C, threshinf=0.2, threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 Parameters ---------- C : ndarray, shape (n_nodes,n_nodes) @@ -71,102 +73,100 @@ def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): C : ndarray, shape (n_nodes,n_nodes) The threshold matrix. 
Each element is in {0,1} """ - H=np.zeros_like(C) - np.fill_diagonal(H,np.diagonal(C)) - C=C-H - C=np.minimum(np.maximum(C,threshinf),threshsup) - C[C==threshsup]=0 - C[C!=0]=1 - - return C - -def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None): + H = np.zeros_like(C) + np.fill_diagonal(H, np.diagonal(C)) + C = C - H + C = np.minimum(np.maximum(C, threshinf), threshsup) + C[C == threshsup] = 0 + C[C != 0] = 1 + + return C + + +def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): """ Create a noisy circular graph """ - g=nx.Graph() + g = nx.Graph() g.add_nodes_from(list(range(N))) for i in range(N): - noise=float(np.random.normal(mu,sigma,1)) + noise = float(np.random.normal(mu, sigma, 1)) if with_noise: - g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise) + g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) else: - g.add_node(i,attr_name=math.sin(2*i*math.pi/N)) - g.add_edge(i,i+1) + g.add_node(i, attr_name=math.sin(2 * i * math.pi / N)) + g.add_edge(i, i + 1) if structure_noise: - randomint=np.random.randint(0,p) - if randomint==0: - if i<=N-3: - g.add_edge(i,i+2) - if i==N-2: - g.add_edge(i,0) - if i==N-1: - g.add_edge(i,1) - g.add_edge(N,0) - noise=float(np.random.normal(mu,sigma,1)) + randomint = np.random.randint(0, p) + if randomint == 0: + if i <= N - 3: + g.add_edge(i, i + 2) + if i == N - 2: + g.add_edge(i, 0) + if i == N - 1: + g.add_edge(i, 1) + g.add_edge(N, 0) + noise = float(np.random.normal(mu, sigma, 1)) if with_noise: - g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise) + g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) else: - g.add_node(N,attr_name=math.sin(2*N*math.pi/N)) + g.add_node(N, attr_name=math.sin(2 * N * math.pi / N)) return g -def graph_colors(nx_graph,vmin=0,vmax=7): - cnorm = mcol.Normalize(vmin=vmin,vmax=vmax) - cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis') + +def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') cpick.set_array([]) val_map = {} - for k,v in nx.get_node_attributes(nx_graph,'attr_name').items(): - val_map[k]=cpick.to_rgba(v) - colors=[] + for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): + val_map[k] = cpick.to_rgba(v) + colors = [] for node in nx_graph.nodes(): colors.append(val_map[node]) return colors - + #%% create dataset # We build a dataset of noisy circular graphs. # Noise is added on the structures by random connections and on the features by gaussian noise. + np.random.seed(30) -X0=[] +X0 = [] for k in range(9): - X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3)) - + X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) + #%% Plot dataset -plt.figure(figsize=(8,10)) +plt.figure(figsize=(8, 10)) for i in range(len(X0)): - plt.subplot(3,3,i+1) - g=X0[i] - pos=nx.kamada_kawai_layout(g) - nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100) -plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20) + plt.subplot(3, 3, i + 1) + g = X0[i] + pos = nx.kamada_kawai_layout(g) + nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) +plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) plt.show() - #%% # We compute the barycenter using FGW. 
Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances -Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0] -ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0] -Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0] -lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel() -sizebary=15 # we choose a barycenter with 15 nodes +Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] +ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] +Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] +lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() +sizebary = 15 # we choose a barycenter with 15 nodes #%% -A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95) +A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) #%% -bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0])) +bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) for i in range(len(A.ravel())): - bary.add_node(i,attr_name=float(A.ravel()[i])) - + bary.add_node(i, attr_name=float(A.ravel()[i])) + #%% pos = nx.kamada_kawai_layout(bary) -nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False) -plt.suptitle('Barycenter',fontsize=20) +nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) +plt.suptitle('Barycenter', fontsize=20) plt.show() - - - - diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index bfa7fb4..ae3c487 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -20,132 +20,132 @@ This example illustrates the computation of FGW for 1D measures[18]. 
import matplotlib.pyplot as pl import numpy as np import ot -from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein +from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein #%% parameters -# We create two 1D random measures -n=20 -n2=30 -sig=1 -sig2=0.1 +# We create two 1D random measures +n = 20 +n2 = 30 +sig = 1 +sig2 = 0.1 np.random.seed(0) -phi=np.arange(n)[:,None] -xs=phi+sig*np.random.randn(n,1) -ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1) +phi = np.arange(n)[:, None] +xs = phi + sig * np.random.randn(n, 1) +ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1) -phi2=np.arange(n2)[:,None] -xt=phi2+sig*np.random.randn(n2,1) -yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1) -yt= yt[::-1,:] +phi2 = np.arange(n2)[:, None] +xt = phi2 + sig * np.random.randn(n2, 1) +yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1) +yt = yt[::-1, :] -p=ot.unif(n) -q=ot.unif(n2) +p = ot.unif(n) +q = ot.unif(n2) #%% plot the distributions pl.close(10) -pl.figure(10,(7,7)) +pl.figure(10, (7, 7)) -pl.subplot(2,1,1) +pl.subplot(2, 1, 1) -pl.scatter(ys,xs,c=phi,s=70) -pl.ylabel('Feature value a',fontsize=20) -pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1) +pl.scatter(ys, xs, c=phi, s=70) +pl.ylabel('Feature value a', fontsize=20) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1) pl.xticks(()) pl.yticks(()) -pl.subplot(2,1,2) -pl.scatter(yt,xt,c=phi2,s=70) -pl.xlabel('coordinates x/y',fontsize=25) -pl.ylabel('Feature value b',fontsize=20) -pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1) +pl.subplot(2, 1, 2) +pl.scatter(yt, xt, c=phi2, s=70) +pl.xlabel('coordinates x/y', fontsize=25) +pl.ylabel('Feature value b', fontsize=20) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1) pl.yticks(()) pl.tight_layout() pl.show() #%% Structure matrices and across-features distance matrix -C1=ot.dist(xs) -C2=ot.dist(xt).T -M=ot.dist(ys,yt) -w1=ot.unif(C1.shape[0]) -w2=ot.unif(C2.shape[0]) -Got=ot.emd([],[],M) +C1 = ot.dist(xs) +C2 = ot.dist(xt).T +M = ot.dist(ys, yt) +w1 = ot.unif(C1.shape[0]) +w2 = ot.unif(C2.shape[0]) +Got = ot.emd([], [], M) #%% -cmap='Reds' +cmap = 'Reds' pl.close(10) -pl.figure(10,(5,5)) -fs=15 -l_x=[0,5,10,15] -l_y=[0,5,10,15,20,25] +pl.figure(10, (5, 5)) +fs = 15 +l_x = [0, 5, 10, 15] +l_y = [0, 5, 10, 15, 20, 25] gs = pl.GridSpec(5, 5) -ax1=pl.subplot(gs[3:,:2]) +ax1 = pl.subplot(gs[3:, :2]) -pl.imshow(C1,cmap=cmap,interpolation='nearest') -pl.title("$C_1$",fontsize=fs) -pl.xlabel("$k$",fontsize=fs) -pl.ylabel("$i$",fontsize=fs) +pl.imshow(C1, cmap=cmap, interpolation='nearest') +pl.title("$C_1$", fontsize=fs) +pl.xlabel("$k$", fontsize=fs) +pl.ylabel("$i$", fontsize=fs) pl.xticks(l_x) pl.yticks(l_x) -ax2=pl.subplot(gs[:3,2:]) +ax2 = pl.subplot(gs[:3, 2:]) -pl.imshow(C2,cmap=cmap,interpolation='nearest') -pl.title("$C_2$",fontsize=fs) -pl.ylabel("$l$",fontsize=fs) +pl.imshow(C2, cmap=cmap, interpolation='nearest') +pl.title("$C_2$", fontsize=fs) +pl.ylabel("$l$", fontsize=fs) #pl.ylabel("$l$",fontsize=fs) pl.xticks(()) pl.yticks(l_y) ax2.set_aspect('auto') -ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1) -pl.imshow(M,cmap=cmap,interpolation='nearest') +ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1) +pl.imshow(M, cmap=cmap, interpolation='nearest') pl.yticks(l_x) pl.xticks(l_y) -pl.ylabel("$i$",fontsize=fs) 
-pl.title("$M_{AB}$",fontsize=fs) -pl.xlabel("$j$",fontsize=fs) +pl.ylabel("$i$", fontsize=fs) +pl.title("$M_{AB}$", fontsize=fs) +pl.xlabel("$j$", fontsize=fs) pl.tight_layout() ax3.set_aspect('auto') pl.show() #%% Computing FGW and GW -alpha=1e-3 - +alpha = 1e-3 + ot.tic() -Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True) +Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) ot.toc() -#%reload_ext WGW -Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True) - +#%reload_ext WGW +Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) + #%% visu OT matrix -cmap='Blues' -fs=15 -pl.figure(2,(13,5)) +cmap = 'Blues' +fs = 15 +pl.figure(2, (13, 5)) pl.clf() -pl.subplot(1,3,1) -pl.imshow(Got,cmap=cmap,interpolation='nearest') +pl.subplot(1, 3, 1) +pl.imshow(Got, cmap=cmap, interpolation='nearest') #pl.xlabel("$y$",fontsize=fs) -pl.ylabel("$i$",fontsize=fs) +pl.ylabel("$i$", fontsize=fs) pl.xticks(()) pl.title('Wasserstein ($M$ only)') -pl.subplot(1,3,2) -pl.imshow(Gg,cmap=cmap,interpolation='nearest') +pl.subplot(1, 3, 2) +pl.imshow(Gg, cmap=cmap, interpolation='nearest') pl.title('Gromov ($C_1,C_2$ only)') pl.xticks(()) -pl.subplot(1,3,3) -pl.imshow(Gwg,cmap=cmap,interpolation='nearest') +pl.subplot(1, 3, 3) +pl.imshow(Gwg, cmap=cmap, interpolation='nearest') pl.title('FGW ($M+C_1,C_2$)') -pl.xlabel("$j$",fontsize=fs) -pl.ylabel("$i$",fontsize=fs) +pl.xlabel("$j$", fontsize=fs) +pl.ylabel("$i$", fontsize=fs) pl.tight_layout() -pl.show() \ No newline at end of file +pl.show() diff --git a/test/test_gromov.py b/test/test_gromov.py index 43b63e1..cd180d4 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -145,7 +145,8 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - + + def test_fgw(): n_samples = 50 # nb samples @@ -155,9 +156,9 @@ def test_fgw(): xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) xt = xs[::-1].copy() - - ys = np.random.randn(xs.shape[0],2) - yt= ys[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) @@ -167,11 +168,11 @@ def test_fgw(): C1 /= C1.max() C2 /= C2.max() - - M=ot.dist(ys,yt) - M/=M.max() - G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5) + M = ot.dist(ys, yt) + M /= M.max() + + G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5) # check constratints np.testing.assert_allclose( @@ -187,36 +188,36 @@ def test_fgw_barycenter(): Xs, ys = ot.datasets.make_data_classif('3gauss', ns) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt) - - ys = np.random.randn(Xs.shape[0],2) - yt= np.random.randn(Xt.shape[0],2) + + ys = np.random.randn(Xs.shape[0], 2) + yt = np.random.randn(Xt.shape[0], 2) C1 = ot.dist(Xs) C2 = ot.dist(Xt) n_samples = 3 - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, - fixed_structure=False,fixed_features=False, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) 
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5, - fixed_structure=True,init_C=init_C,fixed_features=False, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, + fixed_structure=True, init_C=init_C, fixed_features=False, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) - - init_X=np.random.randn(n_samples,ys.shape[1]) - X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5, - fixed_structure=False,fixed_features=True, init_X=init_X, - p=ot.unif(n_samples),loss_fun='square_loss', - max_iter=100, tol=1e-3) + init_X = np.random.randn(n_samples, ys.shape[1]) + + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_optim.py b/test/test_optim.py index 1188ef6..e7ba32a 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -65,8 +65,9 @@ def test_generalized_conditional_gradient(): np.testing.assert_allclose(a, G.sum(1), atol=1e-05) np.testing.assert_allclose(b, G.sum(0), atol=1e-05) - + + def test_solve_1d_linesearch_quad_funct(): - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0) - np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1) -- cgit v1.2.3 From e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:05:38 +0200 Subject: code review1 --- examples/plot_barycenter_fgw.py | 30 +++++++---- examples/plot_fgw.py | 32 ++++++++++-- ot/gromov.py | 108 +++++++++++++++++++++++++++++++++++----- ot/optim.py | 31 ++++++------ test/test_gromov.py | 57 ++++++++++++++++----- 5 files changed, 204 insertions(+), 54 deletions(-) (limited to 'test') diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py index 9eea036..e4be447 100644 --- a/examples/plot_barycenter_fgw.py +++ b/examples/plot_barycenter_fgw.py @@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7): colors.append(val_map[node]) return colors -#%% create dataset +############################################################################## +# Generate data +# ------------- + +#%% circular dataset # We build a dataset of noisy circular graphs. # Noise is added on the structures by random connections and on the features by gaussian noise. 
@@ -135,7 +139,11 @@ X0 = [] for k in range(9): X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) -#%% Plot dataset +############################################################################## +# Plot data +# --------- + +#%% Plot graphs plt.figure(figsize=(8, 10)) for i in range(len(X0)): @@ -146,9 +154,11 @@ for i in range(len(X0)): plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) plt.show() +############################################################################## +# Barycenter computation +# ---------------------- -#%% -# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] @@ -156,14 +166,16 @@ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]) lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() sizebary = 15 # we choose a barycenter with 15 nodes -#%% - A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) -#%% +############################################################################## +# Plot Barycenter +# ------------------------- + +#%% Create the barycenter bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) -for i in range(len(A.ravel())): - bary.add_node(i, attr_name=float(A.ravel()[i])) +for i, v in enumerate(A.ravel()): + bary.add_node(i, attr_name=v) #%% pos = nx.kamada_kawai_layout(bary) diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index ae3c487..43efc94 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -22,12 +22,16 @@ import numpy as np import ot from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein +############################################################################## +# Generate data +# --------- + #%% parameters # We create two 1D random measures -n = 20 -n2 = 30 -sig = 1 -sig2 = 0.1 +n = 20 # number of points in the first distribution +n2 = 30 # number of points in the second distribution +sig = 1 # std of first distribution +sig2 = 0.1 # std of second distribution np.random.seed(0) @@ -43,6 +47,10 @@ yt = yt[::-1, :] p = ot.unif(n) q = ot.unif(n2) +############################################################################## +# Plot data +# --------- + #%% plot the distributions pl.close(10) @@ -64,15 +72,22 @@ pl.yticks(()) pl.tight_layout() pl.show() +############################################################################## +# Create structure matrices and across-feature distance matrix +# --------- #%% Structure matrices and across-features distance matrix C1 = ot.dist(xs) -C2 = ot.dist(xt).T +C2 = ot.dist(xt) M = ot.dist(ys, yt) w1 = ot.unif(C1.shape[0]) w2 = ot.unif(C2.shape[0]) Got = ot.emd([], [], M) +############################################################################## +# Plot matrices +# --------- + #%% cmap = 'Reds' pl.close(10) @@ -112,6 +127,9 @@ pl.tight_layout() ax3.set_aspect('auto') pl.show() +############################################################################## +# Compute FGW/GW +# --------- #%% Computing FGW and GW alpha = 1e-3 @@ -123,6 +141,10 @@ ot.toc() #%reload_ext WGW Gg, log = gromov_wasserstein(C1, C2, p, q, 
loss_fun='square_loss', verbose=True, log=True)
+##############################################################################
+# Visualize transport matrices
+# ----------------------------
+
#%% visu OT matrix
cmap = 'Blues'
fs = 15
diff --git a/ot/gromov.py b/ot/gromov.py
index 5a57dc8..53349b7 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -10,6 +10,7 @@ Gromov-Wasserstein transport method
# Nicolas Courty
# Rémi Flamary
# Titouan Vayer
+#
# License: MIT License
import numpy as np
@@ -351,9 +352,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
- Computes the FGW distance between two graphs see [3]
+ Computes the FGW transport between two graphs, see [24]
.. math::
\gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
distribution in the source space
q : ndarray, shape (nt,)
distribution in the target space
- loss_fun : string,optionnal
+ loss_fun : string, optional
loss function used for the solver
max_iter : int, optional
Max number of iterations
@@ -416,7 +417,86 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
def df(G):
return gwggrad(constC, hC1, hC2, G)
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][::-1][0]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs, see [24]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string, optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used. 
+ If there are convergence issues use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ if log:
+ log['fgw_dist'] = log['loss'][::-1][0]
+ log['T'] = res
+ return log['fgw_dist'], log
+ else:
+ return log['fgw_dist']
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=True, init_C=None, init_X=None):
+ verbose=False, log=False, init_C=None, init_X=None):
"""
Compute the fgw barycenter as presented eq (5) in [24].
----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Barycenters' features
C : ndarray, shape (N,N)
Barycenters' structure matrix
- log_:
+ log_: dictionary
+ Only returned when log=True
T : list of (N,ns) transport matrices
Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
References
@@ -1015,14 +1096,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
-
- log_['Ts_iter'].append(T)
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
err_structure = np.linalg.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
if verbose:
if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T'] = T  # from target to Ys
- log_['p'] = p
- log_['Ms'] = Ms  # Ms are N,ns
+ if log:
+ log_['T'] = T  # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms  # Ms are N,ns
- return X, C, log_
+ if log:
+ return X, C, log_
+ else:
+ return X, C
def update_sructure_matrix(p, lambdas, T, Cs):
diff --git a/ot/optim.py b/ot/optim.py
index 7d103e2..4d428d9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -5,6 +5,7 @@ Optimization algorithms for OT
# Author: Remi Flamary
# Titouan Vayer
+#
# License: MIT License
import numpy as np
@@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
f_val : float
Value of the cost at G
- armijo : bool, optionnal
+ armijo : bool, optional
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False. 
- C1 : ndarray (ns,ns), optionnal + C1 : ndarray (ns,ns), optional Structure matrix in the source domain. Only used when armijo=False - C2 : ndarray (nt,nt), optionnal + C2 : ndarray (nt,nt), optional Structure matrix in the target domain. Only used when armijo=False - reg : float, optionnal + reg : float, optional Regularization parameter. Only used when armijo=False Gc : ndarray (ns,nt) Optimal map found by linearization in the FW algorithm. Only used when armijo=False constC : ndarray (ns,nt) Constant for the gromov cost. See [24]. Only used when armijo=False - M : ndarray (ns,nt), optionnal + M : ndarray (ns,nt), optional Cost matrix between the features. Only used when armijo=False Returns ------- @@ -223,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -261,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: @@ -363,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -402,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: diff --git a/test/test_gromov.py b/test/test_gromov.py index cd180d4..ec85abf 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -2,6 +2,7 @@ # Author: Erwan Vautier # Nicolas Courty +# Titouan Vayer # # License: MIT License @@ -10,6 +11,8 @@ import ot def test_gromov(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -36,6 +39,11 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov + Id = (1 / n_samples) * np.eye(n_samples, n_samples) + + np.testing.assert_allclose( + G, np.flipud(Id), atol=1e-04) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) G = log['T'] @@ -50,6 +58,8 @@ def test_gromov(): def test_entropic_gromov(): + np.random.seed(42) + n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -92,6 +102,7 @@ def test_entropic_gromov(): def test_gromov_barycenter(): + np.random.seed(42) ns = 50 nt = 60 @@ -120,7 +131,7 @@ def test_gromov_barycenter(): def test_gromov_entropic_barycenter(): - + np.random.seed(42) ns = 50 nt = 60 @@ -148,6 +159,8 @@ def test_gromov_entropic_barycenter(): def 
test_fgw():
+ np.random.seed(42)
+
n_samples = 50  # nb samples
mu_s = np.array([0, 0])
@@ -180,8 +193,26 @@ def test_fgw():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04)  # cf convergence fgw
+ Id = (1 / n_samples) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04)  # cf convergence gromov
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04)  # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04)  # cf convergence gromov
+
def test_fgw_barycenter():
+ np.random.seed(42)
ns = 50
nt = 60
@@ -196,28 +227,28 @@ def test_fgw_barycenter():
C2 = ot.dist(Xt)
n_samples = 3
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
--
cgit v1.2.3
From 28059eb5e0aad715823ee4f6509d6a9e3d6e5db0 Mon Sep 17 00:00:00 2001
From: tvayer
Date: Wed, 29 May 2019 17:11:41 +0200
Subject: py2 error
---
test/test_gromov.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'test')
diff --git a/test/test_gromov.py b/test/test_gromov.py
index ec85abf..3ca184b 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -13,7 +13,7 @@ import ot
def test_gromov():
np.random.seed(42)
- n_samples = 50  # nb samples
+ n_samples = 50.0  # nb 
samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) -- cgit v1.2.3 From 63093cef7af3350228251aa930872c6f30789432 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:19:13 +0200 Subject: n_samples float --- test/test_gromov.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index 3ca184b..d7a12f3 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -13,7 +13,7 @@ import ot def test_gromov(): np.random.seed(42) - n_samples = 50.0 # nb samples + n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -39,7 +39,7 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov - Id = (1 / n_samples) * np.eye(n_samples, n_samples) + Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) @@ -161,7 +161,7 @@ def test_gromov_entropic_barycenter(): def test_fgw(): np.random.seed(42) - n_samples = 50.0 # nb samples + n_samples = 50 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -193,7 +193,7 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw - Id = (1 / n_samples) * np.eye(n_samples, n_samples) + Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) # cf convergence gromov -- cgit v1.2.3 From 9bb7d40b563f42bf2875efca860bf0c579307161 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:52:20 +0200 Subject: .0 --- test/test_gromov.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index d7a12f3..b7ede95 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -39,7 +39,7 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov - Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) + Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) @@ -193,7 +193,7 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw - Id = (1 / float(n_samples)) * np.eye(n_samples, n_samples) + Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) # cf convergence gromov -- cgit v1.2.3 From 89a2e0aee4353a051d924de0457f8976c26fa5d7 Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 18:02:27 +0200 Subject: pep8 + err --- test/test_gromov.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index b7ede95..f218b74 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -39,7 +39,7 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov - Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) + Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) @@ -193,7 +193,7 @@ def test_fgw(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence fgw - Id = (1 / 1.0*n_samples) * np.eye(n_samples, n_samples) + Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( G, np.flipud(Id), atol=1e-04) # cf convergence gromov -- cgit v1.2.3 From ad450b0a5bb63ee9731e88d4a8e7423b16f1abd8 Mon Sep 17 00:00:00 2001 From: tvayer 
Date: Tue, 4 Jun 2019 10:32:30 +0200
Subject: changes forgotten comments
---
ot/gromov.py | 26 +++----------------------
ot/optim.py | 32 ++++++++++++++++----------------
ot/utils.py | 8 ++++++++
test/test_optim.py | 6 +++---
4 files changed, 30 insertions(+), 42 deletions(-)
(limited to 'test')
diff --git a/ot/gromov.py b/ot/gromov.py
index 53349b7..ca96b31 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -17,7 +17,7 @@ import numpy as np
from .bregman import sinkhorn
-from .utils import dist
+from .utils import dist, UndefinedParameter
from .optim import cg
@@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
International Conference on Machine Learning (ICML). 2019.
"""
- class UndefinedParameter(Exception):
- pass
-
S = len(Cs)
d = Ys[0].shape[1]  # dimension on the node features
if p is None:
@@ -1049,10 +1046,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [np.outer(p, q) for q in ps]
- # X is N,d
- # Ys is ns,d
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
- # Ms is N,ns
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]  # Ms is N,ns
cpt = 0
err_feature = 1
@@ -1072,27 +1066,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- # X must be N,d
- # Ys must be ns,d
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
- # T must be ns,N
- # Cs must be ns,ns
- # p must be N,1
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- # Ys must be d,ns
- # Ts must be N,ns
- # p must be N,1
- # Ms is N,ns
- # C is N,N
- # Cs is ns,ns
- # p is N,1
- # ps is ns,1
-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
@@ -1115,7 +1095,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
if log:
log_['T'] = T  # from target to Ys
log_['p'] = p
- log_['Ms'] = Ms  # Ms are N,ns
+ log_['Ms'] = Ms
diff --git a/ot/optim.py b/ot/optim.py
index 4d428d9..f94aceb 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -73,8 +73,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
-def do_linesearch(cost, G, deltaG, Mi, f_val,
- armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+def solve_linesearch(cost, G, deltaG, Mi, f_val,
+ armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
Parameters
----------
@@ -93,17 +93,17 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
C1 : ndarray (ns,ns), optional
- Structure matrix in the source domain. Only used when armijo=False
+ Structure matrix in the source domain. Only used and necessary when armijo=False
C2 : ndarray (nt,nt), optional
- Structure matrix in the target domain. Only used when armijo=False
+ Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used when armijo=False
+ Regularization parameter. Only used and necessary when armijo=False
Gc : ndarray (ns,nt)
- Optimal map found by linearization in the FW algorithm. 
Only used when armijo=False + Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used when armijo=False + Constant for the gromov cost. See [24]. Only used and necessary when armijo=False M : ndarray (ns,nt), optional - Cost matrix between the features. Only used when armijo=False + Cost matrix between the features. Only used and necessary when armijo=False Returns ------- alpha : float @@ -128,7 +128,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) c = cost(G) - alpha = solve_1d_linesearch_quad_funct(a, b, c) + alpha = solve_1d_linesearch_quad(a, b, c) fc = None f_val = cost(G + alpha * deltaG) @@ -181,7 +181,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Print information along iterations log : bool, optional record log if True - kwargs : dict + **kwargs : dict Parameters for linesearch Returns @@ -244,7 +244,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -254,7 +254,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, abs_delta_fval = abs(f_val - old_fval) relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -395,7 +395,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, abs_delta_fval = abs(f_val - old_fval) relative_delta_fval = abs_delta_fval / abs(f_val) - if relative_delta_fval < stopThr and abs_delta_fval < stopThr2: + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -413,11 +413,11 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, return G -def solve_1d_linesearch_quad_funct(a, b, c): +def solve_1d_linesearch_quad(a, b, c): """ - Solve on 0,1 the following problem: + For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: .. 
math::
- \min f(x)=a*x^{2}+b*x+c
+ \argmin f(x)=a*x^{2}+b*x+c
Parameters
----------
diff --git a/ot/utils.py b/ot/utils.py
index bb21b38..efd1288 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -487,3 +487,11 @@ class BaseEstimator(object):
(key, self.__class__.__name__))
setattr(self, key, value)
return self
+
+
+class UndefinedParameter(Exception):
+ """
+ Aim at raising an Exception when an undefined parameter is called
+
+ """
+ pass
diff --git a/test/test_optim.py b/test/test_optim.py
index e7ba32a..ae31e1f 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -68,6 +68,6 @@ def test_generalized_conditional_gradient():
def test_solve_1d_linesearch_quad_funct():
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0)
- np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
--
cgit v1.2.3
From 788a6506c9bf3b862a9652d74f65f8d07851e653 Mon Sep 17 00:00:00 2001
From: tvayer
Date: Tue, 4 Jun 2019 11:34:46 +0200
Subject: seed
---
test/test_gromov.py | 26 +++++++++----------------
1 file changed, 9 insertions(+), 17 deletions(-)
(limited to 'test')
diff --git a/test/test_gromov.py b/test/test_gromov.py
index f218b74..70fa83f 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -11,14 +11,12 @@ import ot
def test_gromov():
- np.random.seed(42)
-
n_samples = 50  # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
xt = xs[::-1].copy()
@@ -58,14 +56,12 @@ def test_gromov():
def test_entropic_gromov():
- np.random.seed(42)
-
n_samples = 50  # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
xt = xs[::-1].copy()
@@ -102,13 +98,11 @@ def test_entropic_gromov():
def test_gromov_barycenter():
- np.random.seed(42)
-
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -131,12 +125,11 @@ def test_gromov_barycenter():
def test_gromov_entropic_barycenter():
- np.random.seed(42)
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -159,14 +152,13 @@ def test_gromov_entropic_barycenter():
def test_fgw():
- np.random.seed(42)
n_samples = 50  # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
xt = xs[::-1].copy()
@@ -217,8 +209,8 @@ def test_fgw():
def test_fgw_barycenter():
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt 
= ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
ys = np.random.randn(Xs.shape[0], 2)
yt = np.random.randn(Xt.shape[0], 2)
--
cgit v1.2.3
From 28b549ef3ef93c01462cd811d6e55c36ae5a76a2 Mon Sep 17 00:00:00 2001
From: Hicham Janati
Date: Wed, 12 Jun 2019 15:50:25 +0200
Subject: add test and example of UOT
---
examples/plot_UOT_1D.py | 76 +++++++++++++++++++++++++++++++++++++++++++++++++
test/test_unbalanced.py | 36 +++++++++++++++++++++++
2 files changed, 112 insertions(+)
create mode 100644 examples/plot_UOT_1D.py
create mode 100644 test/test_unbalanced.py
(limited to 'test')
diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py
new file mode 100644
index 0000000..1b1dd9c
--- /dev/null
+++ b/examples/plot_UOT_1D.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Unbalanced optimal transport
+===============================
+
+This example illustrates the computation of Unbalanced Optimal transport
+using a Kullback-Leibler relaxation.
+"""
+
+# Author: Hicham Janati
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100  # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5)  # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# make distributions unbalanced
+b *= 5.
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+##############################################################################
+# Solve Unbalanced Sinkhorn
+# -------------------------
+
+
+#%% Sinkhorn
+
+lambd = 0.1
+alpha = 1.
+Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
+
+pl.show()
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
new file mode 100644
index 0000000..863b6f3
--- /dev/null
+++ b/test/test_unbalanced.py
@@ -0,0 +1,36 @@
+"""Tests for module Unbalanced OT with entropy regularization"""
+
+# Author: Hicham Janati
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_unbalanced():
+ # test generalized sinkhorn for unbalanced OT
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ alpha = 1. 
+ K = np.exp(- M / epsilon) + + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, + stopThr=1e-10, log=True) + + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + u_final = (a / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) -- cgit v1.2.3 From 11381a7ecc79ef719ee9107167c3adc22b5a3f59 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 12 Jun 2019 17:06:32 +0200 Subject: integrate comments of jmassich --- examples/plot_UOT_1D.py | 6 +++--- ot/unbalanced.py | 54 +++++++++++++++---------------------------------- test/test_unbalanced.py | 9 +++++++-- 3 files changed, 26 insertions(+), 43 deletions(-) (limited to 'test') diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py index 1b1dd9c..59b7e77 100644 --- a/examples/plot_UOT_1D.py +++ b/examples/plot_UOT_1D.py @@ -66,9 +66,9 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% Sinkhorn -lambd = 0.1 -alpha = 1. -Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True) +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 8bd02eb..f4208b5 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -6,6 +6,7 @@ Regularized Unbalanced OT # Author: Hicham Janati # License: MIT License +import warnings import numpy as np # from .utils import unif, dist @@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -85,15 +86,14 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. 
: Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -101,17 +101,8 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method. Falling back to classic Sinkhorn Knopp') + warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, @@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -196,18 +187,13 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 - - + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. 
: Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 See Also -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ @@ -215,17 +201,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_stabilized': - # def sink(): - # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) - # elif method.lower() == 'sinkhorn_epsilon_scaling': - # def sink(): - # return sinkhorn_epsilon_scaling( - # a, b, M, reg, numItermax=numItermax, - # stopThr=stopThr, verbose=verbose, log=log, **kwargs) else: - print('Warning : unknown method using classic Sinkhorn Knopp') + warnings.warn('Unknown method using classic Sinkhorn Knopp') def sink(): return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) @@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - a and b are source and target weights - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters @@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. 
: Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
See Also
--------
@@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration', cpt)
u = uprev
v = vprev
break
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 863b6f3..e37498f 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -6,15 +6,19 @@
import numpy as np
import ot
+import pytest
-def test_unbalanced():
+@pytest.mark.parametrize("metric", ["sinkhorn"])
+def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
a = ot.utils.unif(n)
+
+ # make dists unbalanced
b = ot.utils.unif(n) * 1.5
M = ot.dist(x, x)
@@ -23,7 +27,8 @@ def test_unbalanced():
K = np.exp(- M / epsilon)
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, log=True)
+ stopThr=1e-10, method=method,
+ log=True)
# check fixed point equations
fi = alpha / (alpha + epsilon)
v_final = (b / K.T.dot(log["u"])) ** fi
--
cgit v1.2.3
From 12ed1581225f70c7c8777b6ce31710453fda7f51 Mon Sep 17 00:00:00 2001
From: Hicham Janati
Date: Wed, 12 Jun 2019 17:15:40 +0200
Subject: fix typo in test argument
---
test/test_unbalanced.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'test')
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index e37498f..b4fa355 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -9,7 +9,7 @@ import ot
import pytest
-@pytest.mark.parametrize("metric", ["sinkhorn"])
+@pytest.mark.parametrize("method", ["sinkhorn"])
def test_unbalanced_convergence(method):
--
cgit v1.2.3
From 50bc90058940645a13e2f3e41129bdc97161dc63 Mon Sep 17 00:00:00 2001
From: Hicham Janati
Date: Wed, 12 Jun 2019 17:52:02 +0200
Subject: add unbalanced barycenters
---
examples/plot_UOT_barycenter_1D.py | 164 +++++++++++++++++++++++++++++++++++++
ot/unbalanced.py | 118 ++++++++++++++++++++++++++
test/test_unbalanced.py | 30 +++++++
3 files changed, 312 insertions(+)
create mode 100644 examples/plot_UOT_barycenter_1D.py
(limited to 'test')
diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/plot_UOT_barycenter_1D.py
new file mode 100644
index 0000000..8dfb84f
--- /dev/null
+++ b/examples/plot_UOT_barycenter_1D.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wasserstein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
+ +""" + +# Author: Hicham Janati +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa +from matplotlib.collections import PolyCollection + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# make unbalanced dists +a2 *= 3. + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') +pl.tight_layout() + +############################################################################## +# Barycenter computation +# ---------------------- + +#%% non weighted barycenter computation + +weight = 0.5 # 0<=weight<=1 +weights = np.array([1 - weight, weight]) + +# l2bary +bary_l2 = A.dot(weights) + +# wasserstein +reg = 1e-3 +alpha = 1. + +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + +pl.figure(2) +pl.clf() +pl.subplot(2, 1, 1) +for i in range(n_distributions): + pl.plot(x, A[:, i]) +pl.title('Distributions') + +pl.subplot(2, 1, 2) +pl.plot(x, bary_l2, 'r', label='l2') +pl.plot(x, bary_wass, 'g', label='Wasserstein') +pl.legend() +pl.title('Barycenters') +pl.tight_layout() + +############################################################################## +# Barycentric interpolation +# ------------------------- + +#%% barycenter interpolation + +n_weight = 11 +weight_list = np.linspace(0, 1, n_weight) + + +B_l2 = np.zeros((n, n_weight)) + +B_wass = np.copy(B_l2) + +for i in range(0, n_weight): + weight = weight_list[i] + weights = np.array([1 - weight, weight]) + B_l2[:, i] = A.dot(weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + + +#%% plot interpolation + +pl.figure(3) + +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_l2[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with l2') +pl.tight_layout() + +pl.figure(4) +cmap = pl.cm.get_cmap('viridis') +verts = [] +zs = weight_list +for i, z in enumerate(zs): + ys = B_wass[:, i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel(r'$\alpha$') +ax.set_ylim3d(0, 1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max() * 1.01) +pl.title('Barycenter interpolation with Wasserstein') +pl.tight_layout() + +pl.show() diff --git a/ot/unbalanced.py b/ot/unbalanced.py index f4208b5..a30fc18 100644 --- 
a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -380,3 +380,121 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :], log
else:
return u[:, None] * K * v[None, :]
+
+
+def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ """Compute the entropic regularized unbalanced Wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - alpha is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions a_i of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ alpha : float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (d,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+ """
+ p, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = np.exp(- M / reg)
+
+ fi = alpha / (alpha + reg)
+
+ v = np.ones((p, n_hists)) / p
+ u = np.ones((p, 1)) / p
+
+ cpt = 0
+ err = 1.
+
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ u = (A / Kv) ** fi
+ Ktu = K.T.dot(u)
+ q = ((Ktu ** (1 - fi)).dot(weights))
+ q = q ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = (Q / Ktu) ** fi
+
+ if (np.any(Ktu == 0.) 
+ or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration', cpt) + u = uprev + v = vprev + break + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ + np.sum((v - vprev) ** 2) / np.sum((v) ** 2) + if log: + log['err'].append(err) + if verbose: + if cpt % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + if log: + log['niter'] = cpt + log['u'] = u + log['v'] = v + return q, log + else: + return q diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index b4fa355..b39e457 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -39,3 +39,33 @@ def test_unbalanced_convergence(method): u_final, log["u"], atol=1e-05) np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + + +def test_unbalanced_barycenter(): + # test generalized sinkhorn for unbalanced OT barycenter + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + A = rng.rand(n, 2) + + # make dists unbalanced + A = A * np.array([1, 2])[None, :] + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, + stopThr=1e-10, + log=True) + + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (q[:, None] / K.T.dot(log["u"])) ** fi + u_final = (A / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) -- cgit v1.2.3 From 897982718a5fd81a9a591d80a7d50839399fc088 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 18 Jun 2019 16:40:06 +0200 Subject: fix func names + add more tests --- ot/__init__.py | 2 +- ot/bregman.py | 2 +- ot/unbalanced.py | 79 ++++++++++++++++++++++++++++++------------------- test/test_unbalanced.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 127 insertions(+), 35 deletions(-) (limited to 'test') diff --git a/ot/__init__.py b/ot/__init__.py index 361be02..acb05e6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -25,7 +25,7 @@ from . 
import unbalanced
# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced
+from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
from .da import sinkhorn_lpl1_mm
# utils functions
diff --git a/ot/bregman.py b/ot/bregman.py
index 321712b..09716e6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -241,7 +241,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
- b = b.reshape((-1, 1))
+ b = b[:, None]
return sink()
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index a30fc18..97e2576 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
- >>> ot.sinkhorn2(a, b, M, 1, 1)
- array([0.26894142])
+ >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
References
@@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
See Also
--------
- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
return sink()
-def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
+ numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
u"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
>>> a=[.5, .10]
>>> b=[.5, .5]
>>> M=[[0., 1.],[1., 0.]]
- >>> ot.sinkhorn2(a, b, M, 1., 1.)
- array([ 0.26894142])
+ >>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.) 
+ array([0.31912866]) @@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) - else: - warnings.warn('Unknown method using classic Sinkhorn Knopp') + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs) + return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method. Using classic Sinkhorn Knopp') b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: - b = b[None, :] + b = b[:, None] return sink() -def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, >>> a=[.5, .15] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.sinkhorn(a, b, M, 1., 1.) - array([[ 0.36552929, 0.13447071], - [ 0.13447071, 0.36552929]]) - + >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) + array([[0.52761554, 0.22392482], + [0.10286295, 0.32257641]]) References ---------- @@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, if len(b) == 0: b = np.ones(n_b, dtype=np.float64) / n_b - assert n_a == len(a) and n_b == len(b) - if b.ndim > 1: + if len(b.shape) > 1: n_hists = b.shape[1] else: n_hists = 0 @@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((n_a, n_hists)) / n_a + u = np.ones((n_a, 1)) / n_a v = np.ones((n_b, n_hists)) / n_b + a = a.reshape(n_a, 1) else: u = np.ones(n_a) / n_a v = np.ones(n_b) / n_b @@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000, cpt = 0 err = 1. 
+ while (err > stopThr and cpt < numItermax): uprev = u vprev = v @@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %s' % cpt) u = uprev v = vprev break diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index b39e457..1395fe1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -29,7 +29,8 @@ def test_unbalanced_convergence(method): G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, stopThr=1e-10, method=method, log=True) - + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) # check fixed point equations fi = alpha / (alpha + epsilon) v_final = (b / K.T.dot(log["u"])) ** fi @@ -40,6 +41,44 @@ def test_unbalanced_convergence(method): np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + # check if sinkhorn_unbalanced2 returns the correct loss + np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn"]) +def test_unbalanced_multiple_inputs(method): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. + K = np.exp(- M / epsilon) + + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + alpha=alpha, + stopThr=1e-10, method=method, + log=True) + # check fixed point equations + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + + u_final = (a[:, None] / K.dot(log["v"])) ** fi + + np.testing.assert_allclose( + u_final, log["u"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["v"], atol=1e-05) + + assert len(loss) == b.shape[1] + def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter @@ -59,7 +98,6 @@ def test_unbalanced_barycenter(): q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True) - # check fixed point equations fi = alpha / (alpha + epsilon) v_final = (q[:, None] / K.T.dot(log["u"])) ** fi @@ -69,3 +107,40 @@ def test_unbalanced_barycenter(): u_final, log["u"], atol=1e-05) np.testing.assert_allclose( v_final, log["v"], atol=1e-05) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', + 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + alpha = 1. 
+ for method in IMPLEMENTED_METHODS: + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) + with pytest.warns(UserWarning, match='not implemented'): + for method in set(TO_BE_IMPLEMENTED_METHODS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + method=method) -- cgit v1.2.3 From f63f34f8adb6943b6410f8b773b4b4d8f1c7b4ba Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Thu, 20 Jun 2019 14:29:56 +0200 Subject: EMD 1d without doc --- ot/__init__.py | 4 ++-- ot/lp/__init__.py | 43 ++++++++++++++++++++++++++++++++++++++----- ot/lp/emd_wrap.pyx | 35 +++++++++++++++++++++++++++++++++++ test/test_ot.py | 26 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 7 deletions(-) (limited to 'test') diff --git a/ot/__init__.py b/ot/__init__.py index b74b924..5d5b700 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -22,7 +22,7 @@ from . import smooth from . import stochastic # OT functions -from .lp import emd, emd2 +from .lp import emd, emd2, emd_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .da import sinkhorn_lpl1_mm @@ -31,6 +31,6 @@ from .utils import dist, unif, tic, toc, toq __version__ = "0.5.1" -__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets', +__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 02cbd8c..49ded5b 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,12 +14,12 @@ import numpy as np from .import cvx # import compiled emd -from .emd_wrap import emd_c, check_result +from .emd_wrap import emd_c, check_result, emd_1d_sorted from ..utils import parmap from .cvx import barycenter from ..utils import dist -__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx'] +__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted'] def emd(a, b, M, numItermax=100000, log=False): @@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False): b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - # if empty array given then use unifor distributions + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: @@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - # if empty array given then use unifor distributions + # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: @@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None log_dict['displacement_square_norms'] = displacement_square_norms return X, log_dict else: - return X \ No newline at end of file + return X + + +def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False): + """Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + """ + assert x_a.shape[1] 
== x_b.shape[1] == 1, "emd_1d should only be used " + \
+        "with monodimensional data"
+
+    a = np.asarray(a, dtype=np.float64)
+    b = np.asarray(b, dtype=np.float64)
+
+    # if empty array given then use uniform distributions
+    if len(a) == 0:
+        a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
+    if len(b) == 0:
+        b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+
+    perm_a = np.argsort(x_a.reshape((-1, )))
+    perm_b = np.argsort(x_b.reshape((-1, )))
+    inv_perm_a = np.argsort(perm_a)
+    inv_perm_b = np.argsort(perm_b)
+
+    M = dist(x_a[perm_a], x_b[perm_b], metric=metric)
+
+    # weights must be reordered consistently with the sorted sample positions
+    G_sorted, cost = emd_1d_sorted(a[perm_a], b[perm_b], M)
+    G = G_sorted[inv_perm_a, :][:, inv_perm_b]
+    if log:
+        log = {}
+        log['cost'] = cost
+        return G, log
+    return G
\ No newline at end of file
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 83ee6aa..a3d189d 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -93,3 +93,38 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
     cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)

     return G, cost, alpha, beta, result_code
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
+                  np.ndarray[double, ndim=1, mode="c"] v_weights,
+                  np.ndarray[double, ndim=2, mode="c"] M):
+    r"""
+    Solves the 1d OT problem between two sorted distributions with a single greedy north-west-corner pass
+    """
+    cdef double cost = 0.
+    cdef int n = u_weights.shape[0]
+    cdef int m = v_weights.shape[0]
+
+    cdef int i = 0
+    cdef double w_i = u_weights[0]
+    cdef int j = 0
+    cdef double w_j = v_weights[0]
+
+    cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+                                                           dtype=np.float64)
+    while i < n and j < m:
+        if w_i < w_j or j == m - 1:
+            cost += M[i, j] * w_i
+            G[i, j] = w_i
+            i += 1
+            w_j -= w_i
+            w_i = u_weights[i]
+        else:
+            cost += M[i, j] * w_j
+            G[i, j] = w_j
+            j += 1
+            w_i -= w_j
+            w_j = v_weights[j]
+    return G, cost
\ No newline at end of file
diff --git a/test/test_ot.py b/test/test_ot.py
index 7652394..7008002 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -46,6 +46,32 @@ def test_emd_emd2():
     np.testing.assert_allclose(w, 0)


+def test_emd1d():
+    # test emd1d gives similar results as emd
+    n = 20
+    m = 30
+    u = np.random.randn(n, 1)
+    v = np.random.randn(m, 1)
+
+    M = ot.dist(u, v, metric='sqeuclidean')
+
+    G, log = ot.emd([], [], M, log=True)
+    wass = log["cost"]
+    G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
+    wass1d = log["cost"]
+
+    # check loss is similar
+    np.testing.assert_allclose(wass, wass1d)
+
+    # check G is similar
+    np.testing.assert_allclose(G, G_1d)
+
+    # check AssertionError is raised if called on non 1d arrays
+    u = np.random.randn(n, 2)
+    v = np.random.randn(m, 2)
+    np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v)
+
+
 def test_emd_empty():
     # test emd and emd2 for simple identity
     n = 100
--
cgit v1.2.3


From 18502d6861a4977cbade957f2e48eeb8dbb55414 Mon Sep 17 00:00:00 2001
From: Romain Tavenard
Date: Fri, 21 Jun 2019 11:21:08 +0200
Subject: Sparse G matrix for EMD1d + standard metrics computed without cdist

---
 ot/__init__.py     |  4 ++--
 ot/lp/emd_wrap.pyx | 29 +++++++++++++++++++++--------
 test/test_ot.py    | 23 ++++++++++++++++++-----
 3 files changed, 41 insertions(+), 15 deletions(-)
(limited to 'test')

diff --git a/ot/__init__.py b/ot/__init__.py
index 5d5b700..f0e526c 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -22,7 +22,7 @@ from . import smooth
 from . import stochastic
 # OT functions
-from .lp import emd, emd2, emd_1d
+from .lp import emd, emd2, emd_1d, emd2_1d
 from .bregman import sinkhorn, sinkhorn2, barycenter
 from .da import sinkhorn_lpl1_mm

@@ -32,5 +32,5 @@ from .utils import dist, unif, tic, toc, toq
 __version__ = "0.5.1"

 __all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
-           'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+           'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d',
            'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 2966206..ab88d7f 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -101,8 +101,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
 @cython.wraparound(False)
 def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
                   np.ndarray[double, ndim=1, mode="c"] v_weights,
-                  np.ndarray[double, ndim=2, mode="c"] u,
-                  np.ndarray[double, ndim=2, mode="c"] v,
+                  np.ndarray[double, ndim=1, mode="c"] u,
+                  np.ndarray[double, ndim=1, mode="c"] v,
                   str metric='sqeuclidean'):
     r"""
     Solves the 1d OT problem between two sorted distributions with a single greedy north-west-corner pass
@@ -118,21 +118,34 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,

     cdef double m_ij = 0.

-    cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
+    cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
                                                            dtype=np.float64)
+    cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+                                                               dtype=np.int)
+    cdef int cur_idx = 0
     while i < n and j < m:
-        m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
-                    metric=metric)[0, 0]
+        if metric == 'sqeuclidean':
+            m_ij = (u[i] - v[j]) ** 2
+        elif metric == 'cityblock' or metric == 'euclidean':
+            m_ij = np.abs(u[i] - v[j])
+        else:
+            m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
+                        metric=metric)[0, 0]
         if w_i < w_j or j == m - 1:
             cost += m_ij * w_i
-            G[i, j] = w_i
+            G[cur_idx] = w_i
+            indices[cur_idx, 0] = i
+            indices[cur_idx, 1] = j
             i += 1
             w_j -= w_i
             w_i = u_weights[i]
         else:
             cost += m_ij * w_j
-            G[i, j] = w_j
+            G[cur_idx] = w_j
+            indices[cur_idx, 0] = i
+            indices[cur_idx, 1] = j
             j += 1
             w_i -= w_j
             w_j = v_weights[j]
-    return G, cost
\ No newline at end of file
+        cur_idx += 1
+    return G[:cur_idx], indices[:cur_idx], cost
diff --git a/test/test_ot.py b/test/test_ot.py
index 7008002..2a2e0a5 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,6 +7,7 @@
 import warnings

 import numpy as np
+from scipy.stats import wasserstein_distance

 import ot
 from ot.datasets import make_1D_gauss as gauss
@@ -37,7 +38,7 @@ def test_emd_emd2():
     # check G is identity
     np.testing.assert_allclose(G, np.eye(n) / n)

-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1))  # cf convergence sinkhorn
     np.testing.assert_allclose(u, G.sum(0))  # cf convergence sinkhorn

@@ -46,12 +47,13 @@ def test_emd_emd2():
     np.testing.assert_allclose(w, 0)


-def test_emd1d():
+def test_emd_1d_emd2_1d():
     # test emd1d gives similar results as emd
     n = 20
     m = 30
-    u = np.random.randn(n, 1)
-    v = np.random.randn(m, 1)
+    rng = np.random.RandomState(0)
+    u = rng.randn(n, 1)
+    v = rng.randn(m, 1)

     M = ot.dist(u, v, metric='sqeuclidean')

@@ -59,9 +61,20 @@ def test_emd1d():
     wass = log["cost"]
     G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
     wass1d = log["cost"]
+    wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False)
+    wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)

     # check loss is similar
     np.testing.assert_allclose(wass, wass1d)
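    # --- Editor's aside (illustrative sketch, not part of this patch) ---
    # emd_1d/emd2_1d rely on a single greedy "north-west corner" pass over the
    # sorted samples; a pure-numpy version of that pass (with explicit bounds
    # guards, which the Cython code above leaves implicit) that can be checked
    # against wass1d:
    def emd_1d_cost_sketch(x, y):
        x, y = np.sort(x.ravel()), np.sort(y.ravel())
        wx = np.full(len(x), 1. / len(x))    # uniform source weights
        wy = np.full(len(y), 1. / len(y))    # uniform target weights
        i = j = 0
        w_i, w_j, cost = wx[0], wy[0], 0.
        while i < len(x) and j < len(y):
            c = (x[i] - y[j]) ** 2           # sqeuclidean ground cost
            if w_i < w_j or j == len(y) - 1:
                cost += c * w_i              # ship all remaining mass of x[i]
                i += 1
                w_j -= w_i
                w_i = wx[i] if i < len(x) else 0.
            else:
                cost += c * w_j              # saturate y[j], move to next column
                j += 1
                w_i -= w_j
                w_j = wy[j] if j < len(y) else 0.
        return cost
    # np.testing.assert_allclose(emd_1d_cost_sketch(u, v), wass1d)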
+    np.testing.assert_allclose(wass, wass1d_emd2)
+
+    # check loss is similar to scipy's implementation for Euclidean metric
+    wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+    np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+    # check constraints
+    np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+    np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))

     # check G is similar
     np.testing.assert_allclose(G, G_1d)
@@ -86,7 +99,7 @@ def test_emd_empty():
     # check G is identity
     np.testing.assert_allclose(G, np.eye(n) / n)

-    # check constratints
+    # check constraints
     np.testing.assert_allclose(u, G.sum(1))  # cf convergence sinkhorn
     np.testing.assert_allclose(u, G.sum(0))  # cf convergence sinkhorn
--
cgit v1.2.3


From 9e1d74f44473deb1f4766329bb0d1c8af4dfdd73 Mon Sep 17 00:00:00 2001
From: Romain Tavenard
Date: Fri, 21 Jun 2019 18:27:42 +0200
Subject: Started documenting

---
 ot/lp/__init__.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 test/test_ot.py   |  8 +++---
 2 files changed, 79 insertions(+), 6 deletions(-)
(limited to 'test')

diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index decff29..e9635a1 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -313,10 +313,83 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
     return X


-def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
+def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
     """Solves the Earth Movers distance problem between 1d measures and returns
     the OT matrix

+
+    .. math::
+        \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+        s.t. \gamma 1 = a
+             \gamma^T 1= b
+             \gamma\geq 0
+    where :
+
+    - d is the metric
+    - x_a and x_b are the samples
+    - a and b are the sample weights
+
+    Uses the algorithm proposed in [1]_
+
+    Parameters
+    ----------
+    x_a : (ns,) or (ns, 1) ndarray, float64
+        Source dirac locations (on the real line)
+    x_b : (nt,) or (nt, 1) ndarray, float64
+        Target dirac locations (on the real line)
+    a : (ns,) ndarray, float64
+        Source histogram (uniform weight if empty list)
+    b : (nt,) ndarray, float64
+        Target histogram (uniform weight if empty list)
+    dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+        Otherwise returns a sparse representation using scipy's `coo_matrix`
+        format.
+        Due to implementation details, this function runs faster when
+        dense is set to False.
+    metric: str, optional (default='sqeuclidean')
+        Metric to be used. Has to be a string.
+        Due to implementation details, this function runs faster when
+        `'sqeuclidean'` or `'euclidean'` metrics are used.
+    log: boolean, optional (default=False)
+        If True, returns a dictionary containing the cost.
+        Otherwise returns only the optimal transportation matrix.
+
+    Returns
+    -------
+    gamma: (ns, nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log: dict
+        If input log is True, a dictionary containing the cost
+
+
+    Examples
+    --------
+
+    Simple example with obvious solution. The function emd_1d accepts lists and
+    performs automatic conversion to numpy arrays
+
+    >>> import ot
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> x_a = [0., 2.]
+    >>> x_b = [0., 3.]
+    >>> ot.emd_1d(x_a, x_b, a, b)
+    array([[ 0.5,  0. ],
+           [ 0. ,  0.5]])
+
+    References
+    ----------
+
+    ..
[1] TODO + + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -353,7 +426,7 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): return G -def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False): +def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False): """Solves the Earth Movers distance problem between 1d measures and returns the loss diff --git a/test/test_ot.py b/test/test_ot.py index 2a2e0a5..6d6ea26 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -59,10 +59,10 @@ def test_emd_1d_emd2_1d(): G, log = ot.emd([], [], M, log=True) wass = log["cost"] - G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True) + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False) + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) # check loss is similar np.testing.assert_allclose(wass, wass1d) @@ -82,7 +82,7 @@ def test_emd_1d_emd2_1d(): # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) v = np.random.randn(m, 2) - np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v) + np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], []) def test_emd_empty(): -- cgit v1.2.3 From 0d333e004636f5d25edea6bb195e8e4d9a95ba98 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Thu, 27 Jun 2019 10:23:32 +0200 Subject: Improved tests and docs for wasserstein_1d --- ot/__init__.py | 5 +++-- ot/lp/__init__.py | 13 ++++++------- ot/lp/emd_wrap.pyx | 3 ++- test/test_ot.py | 23 +++++++++++++++++++++++ 4 files changed, 34 insertions(+), 10 deletions(-) (limited to 'test') diff --git a/ot/__init__.py b/ot/__init__.py index f0e526c..5bd9bb3 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -22,7 +22,7 @@ from . import smooth from . import stochastic # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d +from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, wasserstein2_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .da import sinkhorn_lpl1_mm @@ -32,5 +32,6 @@ from .utils import dist, unif, tic, toc, toq __version__ = "0.5.1" __all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets', - 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d', + 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 719032b..76c9ec0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -530,13 +530,13 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): - """Solves the Wasserstein distance problem between 1d measures and returns + """Solves the p-Wasserstein distance problem between 1d measures and returns the OT matrix .. math:: - \gamma = arg\min_\gamma \left(\sum_i \sum_j \gamma_{ij} - |x_a[i] - x_b[j]|^p \right)^{1/p} + \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij} + |x_a[i] - x_b[j]|^p \\right)^{1/p} s.t. 
\gamma 1 = a, \gamma^T 1= b, @@ -617,15 +617,14 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): dense=dense, log=log) -def wasserstein2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., - dense=True, log=False): - """Solves the Wasserstein distance problem between 1d measures and returns +def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): + """Solves the p-Wasserstein distance problem between 1d measures and returns the loss .. math:: \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij} - |x_a[i] - x_b[j]|^p \right)^{1/p} + |x_a[i] - x_b[j]|^p \\right)^{1/p} s.t. \gamma 1 = a, \gamma^T 1= b, diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 7134136..42b848f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -13,6 +13,7 @@ cimport numpy as np from ..utils import dist cimport cython +cimport libc.math as math import warnings @@ -159,7 +160,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, elif metric == 'cityblock' or metric == 'euclidean': m_ij = abs(u[i] - v[j]) elif metric == 'minkowski': - m_ij = abs(u[i] - v[j]) ** p + m_ij = math.pow(abs(u[i] - v[j]), p) else: m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), metric=metric)[0, 0] diff --git a/test/test_ot.py b/test/test_ot.py index 6d6ea26..48423e7 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -85,6 +85,29 @@ def test_emd_1d_emd2_1d(): np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], []) +def test_wass_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + + G_1d, log = ot.wasserstein_1d(u, v, [], [], p=2., log=True) + wass1d = log["cost"] + + # check loss is similar + np.testing.assert_allclose(np.sqrt(wass), wass1d) + + # check G is similar + np.testing.assert_allclose(G, G_1d) + + def test_emd_empty(): # test emd and emd2 for simple identity n = 100 -- cgit v1.2.3 From c92e595009ad5e2ae6d4b2c040556cffb6316847 Mon Sep 17 00:00:00 2001 From: Romain Tavenard Date: Thu, 27 Jun 2019 11:08:15 +0200 Subject: Wasserstein defined as the cost itself (do not return transportation matrix) --- ot/__init__.py | 4 +- ot/lp/__init__.py | 125 +++++------------------------------------------------- test/test_ot.py | 6 +-- 3 files changed, 13 insertions(+), 122 deletions(-) (limited to 'test') diff --git a/ot/__init__.py b/ot/__init__.py index 730aa4f..1b3c2fb 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -23,7 +23,7 @@ from . import stochastic from . 
import unbalanced # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, wasserstein2_1d +from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced from .da import sinkhorn_lpl1_mm @@ -35,6 +35,6 @@ __version__ = "0.5.1" __all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d', + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', "barycenter_unbalanced"] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 76c9ec0..a3f5b8d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -21,7 +21,7 @@ from .cvx import barycenter from ..utils import dist __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'wasserstein2_1d'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d'] def emd(a, b, M, numItermax=100000, log=False): @@ -529,9 +529,9 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, return cost -def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): +def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): """Solves the p-Wasserstein distance problem between 1d measures and returns - the OT matrix + the distance .. math:: @@ -560,22 +560,11 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): Target histogram (default is uniform weight) p: float, optional (default=1.0) The order of the p-Wasserstein distance to be computed - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - log: boolean, optional (default=False) - If True, returns a dictionary containing the cost. - Otherwise returns only the optimal transportation matrix. Returns ------- - gamma: (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is True, a dictionary containing the cost + dist: float + p-Wasserstein distance Examples @@ -590,96 +579,8 @@ def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): >>> x_a = [2., 0.] >>> x_b = [0., 3.] >>> ot.wasserstein_1d(x_a, x_b, a, b) - array([[0. , 0.5], - [0.5, 0. ]]) - >>> ot.wasserstein_1d(x_a, x_b) - array([[0. , 0.5], - [0.5, 0. ]]) - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd_1d : EMD for 1d distributions - ot.lp.wasserstein2_1d : Wasserstein for 1d distributions (returns the cost - instead of the transportation matrix) - """ - if log: - G, log = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=dense, log=log) - log['cost'] = np.power(log['cost'], 1. / p) - return G, log - return emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=dense, log=log) - - -def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): - """Solves the p-Wasserstein distance problem between 1d measures and returns - the loss - - - .. 
math:: - \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij} - |x_a[i] - x_b[j]|^p \\right)^{1/p} - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - x_a and x_b are the samples - - a and b are the sample weights - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - p: float, optional (default=1.0) - The order of the p-Wasserstein distance to be computed - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Only used if log is set to True. Due to implementation details, - this function runs faster when dense is set to False. - log: boolean, optional (default=False) - If True, returns a dictionary containing the transportation matrix. - Otherwise returns only the loss. - - Returns - ------- - loss: float - Cost associated to the optimal transportation - log: dict - If input log is True, a dictionary containing the Optimal transportation - matrix for the given parameters - - - Examples - -------- - - Simple example with obvious solution. The function wasserstein2_1d accepts - lists and performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.wasserstein2_1d(x_a, x_b, a, b) 0.5 - >>> ot.wasserstein2_1d(x_a, x_b) + >>> ot.wasserstein_1d(x_a, x_b) 0.5 References @@ -690,14 +591,8 @@ def wasserstein2_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False): See Also -------- - ot.lp.emd2_1d : EMD for 1d distributions - ot.lp.wasserstein_1d : Wasserstein for 1d distributions (returns the - transportation matrix instead of the cost) + ot.lp.emd_1d : EMD for 1d distributions """ - if log: - cost, log = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=dense, log=log) - cost = np.power(cost, 1. / p) - return cost, log - return np.power(emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=dense, log=log), 1. / p) \ No newline at end of file + cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, + dense=False, log=False) + return np.power(cost_emd, 1. / p) diff --git a/test/test_ot.py b/test/test_ot.py index 48423e7..3c4ac11 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -98,15 +98,11 @@ def test_wass_1d(): G, log = ot.emd([], [], M, log=True) wass = log["cost"] - G_1d, log = ot.wasserstein_1d(u, v, [], [], p=2., log=True) - wass1d = log["cost"] + wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) 
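+    # --- Editor's aside (illustrative sketch, not part of this patch) ---
+    # with equal sample sizes and uniform weights, the p-Wasserstein distance
+    # reduces to matching sorted values; n != m in this test, so the check
+    # below is only an extra sanity illustration:
+    u_eq, v_eq = rng.randn(n, 1), rng.randn(n, 1)
+    w2_sorted = np.sqrt(np.mean((np.sort(u_eq, axis=0) - np.sort(v_eq, axis=0)) ** 2))
+    np.testing.assert_allclose(w2_sorted, ot.wasserstein_1d(u_eq, v_eq, [], [], p=2.))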
     # check loss is similar
     np.testing.assert_allclose(np.sqrt(wass), wass1d)


 def test_emd_empty():
     # test emd and emd2 for simple identity
     n = 100
--
cgit v1.2.3


From 93a74fe4d477e1735e9ce21ee4113281f58b4dcf Mon Sep 17 00:00:00 2001
From: Romain Tavenard
Date: Mon, 1 Jul 2019 11:02:11 +0200
Subject: Explicit doctest call in travis + removed ineffective doctest in test_ot

---
 .travis.yml     |  2 +-
 test/test_ot.py | 10 ----------
 2 files changed, 1 insertion(+), 11 deletions(-)
(limited to 'test')

diff --git a/.travis.yml b/.travis.yml
index 50ff22c..d6b4232 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -32,5 +32,5 @@ install:
 script:
   - python setup.py develop
   - flake8 examples/ ot/ test/
-  - python -m pytest -v test/ --cov=ot
+  - python -m pytest -v test/ ot/ --doctest-modules --cov=ot
 #  - py.test ot test

diff --git a/test/test_ot.py b/test/test_ot.py
index 3c4ac11..ac86602 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -14,16 +14,6 @@ from ot.datasets import make_1D_gauss as gauss
 import pytest


-def test_doctest():
-    import doctest
-
-    # test lp solver
-    doctest.testmod(ot.lp, verbose=True)
-
-    # test bregman solver
-    doctest.testmod(ot.bregman, verbose=True)
-
-
 def test_emd_emd2():
     # test emd and emd2 for simple identity
     n = 100
--
cgit v1.2.3


From b05d315b0994d328029d4a4fc082f6994e7f06d1 Mon Sep 17 00:00:00 2001
From: Romain Tavenard
Date: Mon, 1 Jul 2019 11:06:26 +0200
Subject: Moved GPU doctests to test_gpu so tests do not fail if no GPU is available

---
 ot/gpu/bregman.py | 11 -----------
 test/test_gpu.py  | 10 ++++++++++
 2 files changed, 10 insertions(+), 11 deletions(-)
(limited to 'test')

diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 978b307..2e2df83 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -70,17 +70,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
     log : dict
         log dictionary return only if log==True in parameters

-    Examples
-    --------
-
-    >>> import ot
-    >>> a=[.5,.5]
-    >>> b=[.5,.5]
-    >>> M=[[0.,1.],[1.,0.]]
-    >>> ot.sinkhorn(a,b,M,1)
-    array([[ 0.36552929,  0.13447071],
-           [ 0.13447071,  0.36552929]])
-
     References
     ----------

diff --git a/test/test_gpu.py b/test/test_gpu.py
index 6b7fdd4..47b8b6d 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -15,6 +15,16 @@ except ImportError:
     nogpu = True


+@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_old_doctests():
+    a = [.5, .5]
+    b = [.5, .5]
+    M = [[0., 1.], [1., 0.]]
+    G = ot.sinkhorn(a, b, M, 1)
+    np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
+                               [0.13447071, 0.36552929]]))
+
+
 @pytest.mark.skipif(nogpu, reason="No GPU available")
 def test_gpu_dist():
--
cgit v1.2.3


From 64dba525bb5e0ac7952871df859df59fecf19a65 Mon Sep 17 00:00:00 2001
From: Romain Tavenard
Date: Mon, 1 Jul 2019 14:56:55 +0200
Subject: Formatting

---
 test/test_gpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'test')

diff --git a/test/test_gpu.py b/test/test_gpu.py
index 47b8b6d..8e62a74 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -22,7 +22,7 @@ def test_gpu_old_doctests():
     M = [[0., 1.], [1., 0.]]
     G = ot.sinkhorn(a, b, M, 1)
     np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
-                               [0.13447071, 0.36552929]]))
+                                            [0.13447071, 0.36552929]]))


 @pytest.mark.skipif(nogpu, reason="No GPU available")
--
cgit v1.2.3


From d3236cf0cab000b5604f8ede9ebcbdc19d8c213f Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Wed, 3 Jul 2019 13:31:55 +0200
Subject: test raise with pytest in
test_emd_1d_emd2_1d --- Makefile | 4 ++-- ot/__init__.py | 5 ++++- test/test_ot.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) (limited to 'test') diff --git a/Makefile b/Makefile index 84a644b..4cdb7d1 100644 --- a/Makefile +++ b/Makefile @@ -42,10 +42,10 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - $(PYTHON) -m pytest -v test/ --cov=ot --cov-report html:cov_html + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html pytest : FORCE - $(PYTHON) -m pytest -v test/ --cov=ot + $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot uploadpypi : #python setup.py register diff --git a/ot/__init__.py b/ot/__init__.py index ad7b982..35d2ddd 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -10,7 +10,10 @@ a number of functions described below. :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic` - The other sub-modules are not imported due to additional dependencies. + The following sub-modules are not imported due to additional dependencies: + + - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. + - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. """ diff --git a/test/test_ot.py b/test/test_ot.py index ac86602..dacae0a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -72,7 +72,8 @@ def test_emd_1d_emd2_1d(): # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) v = np.random.randn(m, 2) - np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], []) + with pytest.raises(AssertionError): + ot.emd_1d(u, v, [], []) def test_wass_1d(): -- cgit v1.2.3 From 5c0ed104b2890c609bdadfe0fcb0e836ba7a6ef1 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 22 Jul 2019 14:54:01 +0200 Subject: add unbalanced tests with stabilization --- test/test_unbalanced.py | 116 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 77 insertions(+), 39 deletions(-) (limited to 'test') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 1395fe1..fc7aa5e 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -8,8 +8,10 @@ import numpy as np import ot import pytest +from scipy.misc import logsumexp -@pytest.mark.parametrize("method", ["sinkhorn"]) + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -23,29 +25,34 @@ def test_unbalanced_convergence(method): M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + mu = 1. 
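    # --- Editor's aside (illustrative sketch, not part of this patch) ---
    # why the fixed-point checks below move to the log domain: for small reg
    # the kernel exp(-M / reg) and the scalings u, v over/underflow in double
    # precision, while logsumexp stays finite. A minimal standalone example:
    from scipy.special import logsumexp as logsumexp_stable  # scipy.misc variant is deprecated
    row = np.array([0., 5., 50.])
    print(np.exp(-row / 1e-2))            # [1. 0. 0.] -- exp underflows
    print(logsumexp_stable(-row / 1e-2))  # ~0.0 -- computed stably in log space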
- G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, stopThr=1e-10, method=method, log=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (b / K.T.dot(log["u"])) ** fi - u_final = (a / K.dot(log["v"])) ** fi + # in log-domain + fi = mu / (mu + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16) + logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_multiple_inputs(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -59,27 +66,55 @@ def test_unbalanced_multiple_inputs(method): M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + mu = 1. - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - alpha=alpha, + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, stopThr=1e-10, method=method, log=True) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (b / K.T.dot(log["u"])) ** fi - - u_final = (a[:, None] / K.dot(log["v"])) ** fi + # in log-domain + fi = mu / (mu + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) assert len(loss) == b.shape[1] +def test_stabilized_vs_sinkhorn(): + # test if stable version matches sinkhorn + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + mu = 1. + G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon, + mu=mu, + log=True) + G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2) + + def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter n = 100 @@ -92,27 +127,30 @@ def test_unbalanced_barycenter(): A = A * np.array([1, 2])[None, :] M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + mu = 1. 
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu, stopThr=1e-10, log=True) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (q[:, None] / K.T.dot(log["u"])) ** fi - u_final = (A / K.dot(log["v"])) ** fi + fi = mu / (mu + epsilon) + logA = np.log(A + 1e-16) + logq = np.log(q + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logq - logKtu) + u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn'] - TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', - 'sinkhorn_epsilon_scaling'] + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 @@ -126,21 +164,21 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. - alpha = 1. + mu = 1. for method in IMPLEMENTED_METHODS: - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) with pytest.warns(UserWarning, match='not implemented'): for method in set(TO_BE_IMPLEMENTED_METHODS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) with pytest.raises(ValueError): for method in set(NOT_VALID_TOKENS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, method=method) -- cgit v1.2.3 From 09f3f640fc46ba4905d5508b704f2e5a90dda295 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 23 Jul 2019 21:28:30 +0200 Subject: fix issue 94 + add test --- ot/bregman.py | 10 +++++++--- test/test_bregman.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) (limited to 'test') diff --git a/ot/bregman.py b/ot/bregman.py index f39145d..70e4208 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -765,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) if log: - log['logu'] = alpha / reg + np.log(u) - log['logv'] = beta / reg + np.log(v) + if nbb: + alpha = alpha[:, None] + beta = beta[:, None] + logu = alpha / reg + np.log(u) + logv = beta / reg + np.log(v) + log['logu'] = logu + log['logv'] = logv log['alpha'] = alpha + reg * np.log(u) log['beta'] = beta + reg * np.log(v) log['warmstart'] = (log['alpha'], log['beta']) diff --git a/test/test_bregman.py b/test/test_bregman.py index 7f4972c..83ebba8 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence(): emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp 
sinkhorn np.testing.assert_allclose( emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn + + +def test_stabilized_vs_sinkhorn_multidim(): + # test if stable version matches sinkhorn + # for multidimensional inputs + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon, + method="sinkhorn_stabilized", + log=True) + G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2) -- cgit v1.2.3 From a507556b1901e16351c211e69b38d8d74ac2bc3d Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Tue, 23 Jul 2019 21:51:53 +0200 Subject: rebase unbalanced --- test/test_unbalanced.py | 116 ++++++++++++++++-------------------------------- 1 file changed, 39 insertions(+), 77 deletions(-) (limited to 'test') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index fc7aa5e..1395fe1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -8,10 +8,8 @@ import numpy as np import ot import pytest -from scipy.misc import logsumexp - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -25,34 +23,29 @@ def test_unbalanced_convergence(method): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, stopThr=1e-10, method=method, log=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) # check fixed point equations - # in log-domain - fi = mu / (mu + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - - v_final = fi * (logb - logKtu) - u_final = fi * (loga - logKv) + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + u_final = (a / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_multiple_inputs(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -66,55 +59,27 @@ def test_unbalanced_multiple_inputs(method): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. 
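    # --- Editor's aside (illustrative sketch, not part of this patch) ---
    # the fixed-point checks used throughout these tests rest on one identity:
    # a plan of the form G = diag(u) K diag(v) has marginals u * (K v) and
    # v * (K^T u); a tiny standalone verification:
    rng_sk = np.random.RandomState(1)
    K_sk = np.exp(-rng_sk.rand(4, 5))
    u_sk, v_sk = rng_sk.rand(4), rng_sk.rand(5)
    G_sk = u_sk[:, None] * K_sk * v_sk[None, :]
    np.testing.assert_allclose(G_sk.sum(axis=1), u_sk * K_sk.dot(v_sk))
    np.testing.assert_allclose(G_sk.sum(axis=0), v_sk * K_sk.T.dot(u_sk))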
+ K = np.exp(- M / epsilon) - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + alpha=alpha, stopThr=1e-10, method=method, log=True) # check fixed point equations - # in log-domain - fi = mu / (mu + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) - v_final = fi * (logb - logKtu) - u_final = fi * (loga - logKv) + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + + u_final = (a[:, None] / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): - # test if stable version matches sinkhorn - n = 100 - - # Gaussian distributions - a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std - b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) - b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) - - # creating matrix A containing all distributions - b = np.vstack((b1, b2)).T - - M = ot.utils.dist0(n) - M /= np.median(M) - epsilon = 0.1 - mu = 1. - G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon, - mu=mu, - log=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, - method="sinkhorn", log=True) - - np.testing.assert_allclose(G, G2) - - def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter n = 100 @@ -127,30 +92,27 @@ def test_unbalanced_barycenter(): A = A * np.array([1, 2])[None, :] M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu, + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True) # check fixed point equations - fi = mu / (mu + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) - v_final = fi * (logq - logKtu) - u_final = fi * (logA - logKv) + fi = alpha / (alpha + epsilon) + v_final = (q[:, None] / K.T.dot(log["u"])) ** fi + u_final = (A / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] - TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] + IMPLEMENTED_METHODS = ['sinkhorn'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', + 'sinkhorn_epsilon_scaling'] NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 @@ -164,21 +126,21 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. 
for method in IMPLEMENTED_METHODS: - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) with pytest.warns(UserWarning, match='not implemented'): for method in set(TO_BE_IMPLEMENTED_METHODS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) with pytest.raises(ValueError): for method in set(NOT_VALID_TOKENS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) -- cgit v1.2.3 From 9d4b786a036ac95989825beec819521089fb4feb Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 16:37:58 -0400 Subject: fixes for travis, added test, minor nits --- .travis.yml | 5 ++-- ot/da.py | 2 +- ot/utils.py | 4 +++- test/test_da.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 4 deletions(-) (limited to 'test') diff --git a/.travis.yml b/.travis.yml index 5e5694b..72fd29a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ matrix: python: 3.5 - os: linux sudo: required - python: 3.6 + python: 3.6 - os: linux sudo: required python: 2.7 @@ -21,7 +21,6 @@ before_install: - ./.travis/before_install.sh before_script: # configure a headless display to test plot generation - "export DISPLAY=:99.0" - - "sh -e /etc/init.d/xvfb start" - sleep 3 # give xvfb some time to start # command to install dependencies install: @@ -30,6 +29,8 @@ install: - pip install flake8 pytest "pytest-cov<2.6" - pip install . # command to run tests + check syntax style +services: + - xvfb script: - python setup.py develop - flake8 examples/ ot/ test/ diff --git a/ot/da.py b/ot/da.py index c1d9849..2af855d 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1852,7 +1852,7 @@ class UnbalancedSinkhornTransport(BaseTransport): """ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', - max_iter=10, tol=10e-9, verbose=False, log=False, + max_iter=10, tol=1e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): diff --git a/ot/utils.py b/ot/utils.py index be839f8..a334fea 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -178,7 +178,9 @@ def cost_normalization(C, norm=None): The input cost matrix normalized according to given norm. 
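    Examples
    --------
    A small usage sketch (illustrative; assumes a float ndarray and the
    ``None`` branch added in the patch below):

    >>> C = np.array([[0., 1.], [2., 4.]])
    >>> cost_normalization(C.copy(), "max")
    array([[0.  , 0.25],
           [0.5 , 1.  ]])
    >>> cost_normalization(C.copy(), None)  # None now passes through unchanged
    array([[0., 1.],
           [2., 4.]])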
""" - if norm == "median": + if norm is None: + pass + elif norm == "median": C /= float(np.median(C)) elif norm == "max": C /= float(np.max(C)) diff --git a/test/test_da.py b/test/test_da.py index f7f3a9d..9efd2d9 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -245,6 +245,79 @@ def test_sinkhorn_transport_class(): assert len(otda.log_.keys()) != 0 +def test_unbalanced_sinkhorn_transport_class(): + """test_sinkhorn_transport + """ + + ns = 150 + nt = 200 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + otda = ot.da.UnbalancedSinkhornTransport() + + # test its computed + otda.fit(Xs=Xs, Xt=Xt) + assert hasattr(otda, "cost_") + assert hasattr(otda, "coupling_") + assert hasattr(otda, "log_") + + # test dimensions of coupling + assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + + # test margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = otda.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) + + Xs_new, _ = make_data_classif('3gauss', ns + 1) + transp_Xs_new = otda.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + # test inverse transform + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) + + Xt_new, _ = make_data_classif('3gauss2', nt + 1) + transp_Xt_new = otda.inverse_transform(Xt=Xt_new) + + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) + + # test fit_transform + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) + + # test unsupervised vs semi-supervised mode + otda_unsup = ot.da.SinkhornTransport() + otda_unsup.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(otda_unsup.cost_) + + otda_semi = ot.da.SinkhornTransport() + otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(otda_semi.cost_) + + # check that the cost matrix norms are indeed different + assert n_unsup != n_semisup, "semisupervised mode not working" + + # check everything runs well with log=True + otda = ot.da.SinkhornTransport(log=True) + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert len(otda.log_.keys()) != 0 + + def test_emd_transport_class(): """test_sinkhorn_transport """ -- cgit v1.2.3 From ce86d1476b32771d32b7e55566e7cab45bb57b3a Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 17:03:08 -0400 Subject: Fix in test: no margin constraints here --- test/test_da.py | 8 -------- 1 file changed, 8 deletions(-) (limited to 'test') diff --git a/test/test_da.py b/test/test_da.py index 9efd2d9..2a5e50e 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -267,14 +267,6 @@ def test_unbalanced_sinkhorn_transport_class(): assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) -- cgit v1.2.3 From cfdbbd21642c6082164b84db78c2ead07499a113 Mon Sep 17 00:00:00 2001 From: 
Hicham Janati Date: Fri, 19 Jul 2019 17:04:14 +0200 Subject: remove square in convergence check add unbalanced with stabilization add unbalanced tests with stabilization fix doctest examples add xvfb in travis remove explicit call xvfb in travis change alpha to reg_m minor flake8 remove redundant sink definitions + better doc and naming add stabilized unbalanced barycenter + add not converged warnings add test for stable barycenter add generic barycenter func + make method funcs private fix typo + add method test for barycenters fix doc examples + add xml to gitignore fix whitespace in example change logsumexp import - scipy deprecation warning fix doctest improve naming + add stable barycenter in bregman add test for stable bar + test the method arg in bregman --- .gitignore | 3 + ot/__init__.py | 18 +- ot/bregman.py | 530 +++++++++++++++++++++----------- ot/unbalanced.py | 803 +++++++++++++++++++++++++++++++++++++++--------- pytest.ini | 0 test/test_bregman.py | 72 ++++- test/test_unbalanced.py | 163 +++++++--- 7 files changed, 1205 insertions(+), 384 deletions(-) create mode 100644 pytest.ini (limited to 'test') diff --git a/.gitignore b/.gitignore index 42a9aad..dadf84c 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,9 @@ coverage.xml *.mo *.pot +# xml +*.xml + # Django stuff: *.log local_settings.py diff --git a/ot/__init__.py b/ot/__init__.py index 35ae6fc..7d9615a 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,7 +1,7 @@ """ -This is the main module of the POT toolbox. It provides easy access to -a number of sub-modules and functions described below. +This is the main module of the POT toolbox. It provides easy access to +a number of sub-modules and functions described below. .. note:: @@ -14,27 +14,27 @@ a number of sub-modules and functions described below. - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT problems. - - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov + - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov Wasserstein problems. - - :any:`ot.optim` contains generic solvers OT based optimization problems + - :any:`ot.optim` contains generic solvers OT based optimization problems - :any:`ot.da` contains classes and function related to Monge mapping estimation and Domain Adaptation (DA). - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers - - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein + - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein Discriminant Analysis. - - :any:`ot.utils` contains utility functions such as distance computation and - timing. + - :any:`ot.utils` contains utility functions such as distance computation and + timing. - :any:`ot.datasets` contains toy dataset generation functions. - :any:`ot.plot` contains visualization functions - :any:`ot.stochastic` contains stochastic solvers for regularized OT. - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT. .. 
 warning::
-    The list of automatically imported sub-modules is as follows:
+    The list of automatically imported sub-modules is as follows:
     :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
     :py:mod:`ot.utils`, :py:mod:`ot.datasets`,
     :py:mod:`ot.gromov`, :py:mod:`ot.smooth`
-    :py:mod:`ot.stochastic`
+    :py:mod:`ot.stochastic`

     The following sub-modules are not imported due to additional dependencies:

diff --git a/ot/bregman.py b/ot/bregman.py
index 70e4208..2f27d58 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,10 +7,12 @@ Bregman projections for regularized OT
 # Nicolas Courty
 # Kilian Fatras
 # Titouan Vayer
+# Hicham Janati
 #
 # License: MIT License

 import numpy as np
+import warnings

 from .utils import unif, dist

@@ -31,7 +33,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
         \gamma\geq 0
     where :

-    - M is the (ns,nt) metric cost matrix
+    - M is the (dim_a, dim_b) metric cost matrix
     - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
     - a and b are source and target weights (sum to 1)

@@ -40,12 +42,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,

     Parameters
     ----------
-    a : ndarray, shape (ns,)
+    a : ndarray, shape (dim_a,)
         samples weights in the source domain
-    b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+    b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
         samples in the target domain, compute sinkhorn with multiple targets
         and fixed M if b is a matrix (return OT loss + dual variables in log)
-    M : ndarray, shape (ns, nt)
+    M : ndarray, shape (dim_a, dim_b)
         loss matrix
     reg : float
         Regularization term >0
@@ -64,7 +66,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,

     Returns
     -------
-    gamma : ndarray, shape (ns, nt)
+    gamma : ndarray, shape (dim_a, dim_b)
         Optimal transportation matrix for the given parameters
     log : dict
         log dictionary return only if log==True in parameters

@@ -103,30 +105,23 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
     """

     if method.lower() == 'sinkhorn':
-        def sink():
-            return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
-                                  stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return _sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+                               stopThr=stopThr, verbose=verbose, log=log,
+                               **kwargs)
     elif method.lower() == 'greenkhorn':
-        def sink():
-            return greenkhorn(a, b, M, reg, numItermax=numItermax,
-                              stopThr=stopThr, verbose=verbose, log=log)
+        return _greenkhorn(a, b, M, reg, numItermax=numItermax,
+                           stopThr=stopThr, verbose=verbose, log=log)
     elif method.lower() == 'sinkhorn_stabilized':
-        def sink():
-            return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
-                                       stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return _sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+                                    stopThr=stopThr, verbose=verbose,
+                                    log=log, **kwargs)
     elif method.lower() == 'sinkhorn_epsilon_scaling':
-        def sink():
-            return sinkhorn_epsilon_scaling(
-                a, b, M, reg, numItermax=numItermax,
-                stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        return _sinkhorn_epsilon_scaling(a, b, M, reg,
+                                         numItermax=numItermax,
+                                         stopThr=stopThr, verbose=verbose,
+                                         log=log, **kwargs)
     else:
-        print('Warning : unknown method using classic Sinkhorn Knopp')
-
-        def sink():
-            return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
-                                  stopThr=stopThr, verbose=verbose, log=log, **kwargs)
-
-        return sink()
+        raise ValueError("Unknown method '%s'." % method)
% method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -146,7 +141,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - M is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (sum to 1) @@ -155,12 +150,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Parameters ---------- - a : ndarray, shape (ns,) + a : ndarray, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (ns, nt) + M : ndarray, shape (dim_a, n_b) loss matrix reg : float Regularization term >0 @@ -218,35 +213,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] """ - + b = np.asarray(b, dtype=np.float64) + if len(b.shape) < 2: + b = b[:, None] if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return _sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - def sink(): - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return _sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': - def sink(): - return sinkhorn_epsilon_scaling( - a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return _sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: - print('Warning : unknown method using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp(a, b, M, reg, **kwargs) + raise ValueError("Unknown method '%s'." 
% method) - b = np.asarray(b, dtype=np.float64) - if len(b.shape) < 2: - b = b[:, None] - return sink() - - -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def _sinkhorn_knopp(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -262,7 +247,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - M is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (sum to 1) @@ -271,12 +256,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Parameters ---------- - a : ndarray, shape (ns,) + a : ndarray, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (ns, nt) + M : ndarray, shape (dim_a, n_b) loss matrix reg : float Regularization term >0 @@ -291,7 +276,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (dim_a, n_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -331,25 +316,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] # init data - Nini = len(a) - Nfin = len(b) + dim_a = len(a) + dim_b = len(b) if len(b.shape) > 1: - nbb = b.shape[1] + n_hists = b.shape[1] else: - nbb = 0 + n_hists = 0 if log: log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances - if nbb: - u = np.ones((Nini, nbb)) / Nini - v = np.ones((Nfin, nbb)) / Nfin + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b else: - u = np.ones(Nini) / Nini - v = np.ones(Nfin) / Nfin + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b # print(reg) @@ -384,13 +369,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - if nbb: - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + if n_hists: + np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b)**2 # violation of marginal + err = np.linalg.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) @@ -404,7 +388,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['u'] = u log['v'] = v - if nbb: # return only loss + if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log @@ -419,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): +def _greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -443,7 +427,7 @@ def greenkhorn(a, b, M, 
reg, numItermax=10000, stopThr=1e-9, verbose=False, log= \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - M is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (sum to 1) @@ -451,12 +435,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= Parameters ---------- - a : ndarray, shape (ns,) + a : ndarray, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) + b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (ns, nt) + M : ndarray, shape (dim_a, n_b) loss matrix reg : float Regularization term >0 @@ -469,7 +453,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (dim_a, n_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -481,7 +465,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] - >>> ot.bregman.greenkhorn(a, b, M, 1) + >>> ot.bregman._greenkhorn(a, b, M, 1) array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) @@ -509,16 +493,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - n = a.shape[0] - m = b.shape[0] + dim_a = a.shape[0] + dim_b = b.shape[0] # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty_like(M) np.divide(M, -reg, out=K) np.exp(K, out=K) - u = np.full(n, 1. / n) - v = np.full(m, 1. / m) + u = np.full(dim_a, 1. / dim_a) + v = np.full(dim_b, 1. 
/ dim_b) G = u[:, np.newaxis] * K * v[np.newaxis, :] viol = G.sum(1) - a @@ -571,8 +555,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= return G -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, - warmstart=None, verbose=False, print_period=20, log=False, **kwargs): +def _sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=20, + log=False, **kwargs): r""" Solve the entropic regularization OT problem with log stabilization @@ -588,7 +573,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - M is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (sum to 1) @@ -599,11 +584,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Parameters ---------- - a : ndarray, shape (ns,) + a : ndarray, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (nt,) + b : ndarray, shape (dim_b,) samples in the target domain - M : ndarray, shape (ns, nt) + M : ndarray, shape (dim_a, n_b) loss matrix reg : float Regularization term >0 @@ -622,7 +607,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (dim_a, n_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -634,7 +619,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.bregman.sinkhorn_stabilized(a,b,M,1) + >>> ot.bregman._sinkhorn_stabilized(a, b, M, 1) array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) @@ -667,10 +652,10 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # test if multiple target if len(b.shape) > 1: - nbb = b.shape[1] + n_hists = b.shape[1] a = a[:, np.newaxis] else: - nbb = 0 + n_hists = 0 # init data na = len(a) @@ -687,8 +672,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, else: alpha, beta = warmstart - if nbb: - u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb + if n_hists: + u, v = np.ones((na, n_hists)) / na, np.ones((nb, n_hists)) / nb else: u, v = np.ones(na) / na, np.ones(nb) / nb @@ -720,13 +705,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # remove numerical problems and store them in K if np.abs(u).max() > tau or np.abs(v).max() > tau: - if nbb: + if n_hists: alpha, beta = alpha + reg * \ np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) - if nbb: - u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb + if n_hists: + u, v = np.ones((na, n_hists)) / na, np.ones((nb, n_hists)) / nb else: u, v = np.ones(na) / na, np.ones(nb) / nb K = get_K(alpha, beta) @@ -734,12 +719,15 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if cpt % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations - if nbb: - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + if n_hists: + err_u = abs(u - uprev).max() + err_u /= max(abs(u).max(), abs(uprev).max(), 1.) 
+ err_v = abs(v - vprev).max() + err_v /= max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) else: transp = get_Gamma(alpha, beta, u, v) - err = np.linalg.norm((np.sum(transp, axis=0) - b))**2 + err = np.linalg.norm((np.sum(transp, axis=0) - b)) if log: log['err'].append(err) @@ -766,7 +754,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, cpt = cpt + 1 if log: - if nbb: + if n_hists: alpha = alpha[:, None] beta = beta[:, None] logu = alpha / reg + np.log(u) @@ -776,26 +764,28 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, log['alpha'] = alpha + reg * np.log(u) log['beta'] = beta + reg * np.log(v) log['warmstart'] = (log['alpha'], log['beta']) - if nbb: - res = np.zeros((nbb)) - for i in range(nbb): + if n_hists: + res = np.zeros((n_hists)) + for i in range(n_hists): res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) return res, log else: return get_Gamma(alpha, beta, u, v), log else: - if nbb: - res = np.zeros((nbb)) - for i in range(nbb): + if n_hists: + res = np.zeros((n_hists)) + for i in range(n_hists): res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) return res else: return get_Gamma(alpha, beta, u, v) -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, - tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): +def _sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, + numInnerItermax=100, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=10, + log=False, **kwargs): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. @@ -812,7 +802,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - M is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (sum to 1) @@ -823,18 +813,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne Parameters ---------- - a : ndarray, shape (ns,) + a : ndarray, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (nt,) + b : ndarray, shape (dim_b,) samples in the target domain - M : ndarray, shape (ns, nt) + M : ndarray, shape (dim_a, n_b) loss matrix reg : float Regularization term >0 tau : float thershold for max value in u or v for log scaling - tau : float - thershold for max value in u or v for log scaling warmstart : tuple of vectors if given then sarting values for alpha an beta log scalings numItermax : int, optional @@ -852,7 +840,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (dim_a, n_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -864,7 +852,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] - >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1) + >>> ot.bregman._sinkhorn_epsilon_scaling(a, b, M, 1) array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) @@ -893,8 +881,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne b = np.ones((M.shape[1],), dtype=np.float64) / 
M.shape[1]

     # init data
-    na = len(a)
-    nb = len(b)
+    dim_a = len(a)
+    dim_b = len(b)

     # relative numerical precision with 64 bits
     numItermin = 35

@@ -907,14 +895,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
     # we assume that no distances are null except those of the diagonal of
     # distances
     if warmstart is None:
-        alpha, beta = np.zeros(na), np.zeros(nb)
+        alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
     else:
         alpha, beta = warmstart

     def get_K(alpha, beta):
         """log space computation"""
-        return np.exp(-(M - alpha.reshape((na, 1)) -
-                        beta.reshape((1, nb))) / reg)
+        return np.exp(-(M - alpha.reshape((dim_a, 1))
+                        - beta.reshape((1, dim_b))) / reg)
     # print(np.min(K))

     def get_reg(n):  # exponential decreasing
@@ -927,7 +915,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne

         regi = get_reg(cpt)

-        G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
+        G, logi = _sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
             alpha, beta), verbose=False, print_period=20, tau=tau, log=True)

         alpha = logi['alpha']
@@ -986,8 +974,8 @@ def projC(gamma, q):
     return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))


-def barycenter(A, M, reg, weights=None, numItermax=1000,
-               stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+               stopThr=1e-4, verbose=False, log=False, **kwargs):
     r"""Compute the entropic regularized wasserstein barycenter of distributions A

     The function solves the following optimization problem:

     .. math::
        \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

     where :

     - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
     - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
     - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT

     The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_

     Parameters
     ----------
-    A : ndarray, shape (d,n)
-        n training distributions a_i of size d
-    M : ndarray, shape (d,d)
-        loss matrix for OT
+    A : ndarray, shape (dim, n_hists)
+        n_hists training distributions a_i of size dim
+    M : ndarray, shape (dim, dim)
+        loss matrix for OT
     reg : float
-        Regularization term >0
-    weights : ndarray, shape (n,)
+        Regularization term > 0
+    method : str (optional)
+        method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
+    weights : ndarray, shape (n_hists,)
         Weights of each histogram a_i on the simplex (barycentric coordinates)
     numItermax : int, optional
         Max number of iterations
@@ -1025,7 +1015,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,

     Returns
     -------
-    a : (d,) ndarray
+    a : (dim,) ndarray
         Wasserstein barycenter
     log : dict
         log dictionary return only if log==True in parameters
@@ -1036,8 +1026,70 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,

     .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

     """
+
+    if method.lower() == 'sinkhorn':
+        return _barycenter(A, M, reg, numItermax=numItermax,
+                           stopThr=stopThr, verbose=verbose, log=log,
+                           **kwargs)
+    elif method.lower() == 'sinkhorn_stabilized':
+        return _barycenter_stabilized(A, M, reg, numItermax=numItermax,
+                                      stopThr=stopThr, verbose=verbose,
+                                      log=log, **kwargs)
+    else:
+        raise ValueError("Unknown method '%s'." % method)
+
+
+def _barycenter(A, M, reg, weights=None, numItermax=1000,
+                stopThr=1e-4, verbose=False, log=False):
+    r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+    Parameters
+    ----------
+    A : ndarray, shape (dim, n_hists)
+        n_hists training distributions a_i of size dim
+    M : ndarray, shape (dim, dim)
+        loss matrix for OT
+    reg : float
+        Regularization term > 0
+    weights : ndarray, shape (n_hists,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    Returns
+    -------
+    a : (dim,) ndarray
+        Wasserstein barycenter
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+    """
+
     if weights is None:
@@ -1082,6 +1134,136 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
     return geometricBar(weights, UKv)


+def _barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
+                           stopThr=1e-4, verbose=False, log=False):
+    r"""Compute the entropic regularized wasserstein barycenter of distributions A
+    with stabilization.
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+    Parameters
+    ----------
+    A : ndarray, shape (dim, n_hists)
+        n_hists training distributions a_i of size dim
+    M : ndarray, shape (dim, dim)
+        loss matrix for OT
+    reg : float
+        Regularization term > 0
+    tau : float
+        threshold for max value in u or v for log scaling
+    weights : ndarray, shape (n_hists,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    a : (dim,) ndarray
+        Wasserstein barycenter
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
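+
+    Examples
+    --------
+    (Editor's sketch: a minimal two-histogram check; the input values are
+    assumed for illustration and are not part of the original patch.)
+
+    >>> import numpy as np
+    >>> A = np.array([[.5, .4], [.5, .6]])
+    >>> M = np.array([[0., 1.], [1., 0.]])
+    >>> q = _barycenter_stabilized(A, M, reg=1.)
+    >>> q.shape
+    (2,)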
+
+    """
+
+    dim, n_hists = A.shape
+    if weights is None:
+        weights = np.ones(n_hists) / n_hists
+    else:
+        assert(len(weights) == A.shape[1])
+
+    if log:
+        log = {'err': []}
+
+    u = np.ones((dim, n_hists)) / dim
+    v = np.ones((dim, n_hists)) / dim
+
+    # print(reg)
+    # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+    K = np.empty(M.shape, dtype=M.dtype)
+    np.divide(M, -reg, out=K)
+    np.exp(K, out=K)
+
+    cpt = 0
+    err = 1.
+    alpha = np.zeros(dim)
+    beta = np.zeros(dim)
+    q = np.ones(dim) / dim
+    while (err > stopThr and cpt < numItermax):
+        qprev = q
+        Kv = K.dot(v)
+        u = A / (Kv + 1e-16)
+        Ktu = K.T.dot(u)
+        q = geometricBar(weights, Ktu)
+        Q = q[:, None]
+        v = Q / (Ktu + 1e-16)
+        absorbing = False
+        if (u > tau).any() or (v > tau).any():
+            absorbing = True
+            alpha = alpha + reg * np.log(np.max(u, 1))
+            beta = beta + reg * np.log(np.max(v, 1))
+            K = np.exp((alpha[:, None] + beta[None, :] -
+                        M) / reg)
+            v = np.ones_like(v)
+        Kv = K.dot(v)
+        if (np.any(Ktu == 0.)
+                or np.any(np.isnan(u)) or np.any(np.isnan(v))
+                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+            # we have reached the machine precision
+            # come back to previous solution and quit loop
+            warnings.warn('Numerical errors at iteration %s' % cpt)
+            q = qprev
+            break
+        if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+            # we can speed up the process by checking for the error only all
+            # the 10th iterations
+            err = abs(u * Kv - A).max()
+            if log:
+                log['err'].append(err)
+            if verbose:
+                if cpt % 50 == 0:
+                    print(
+                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+                print('{:5d}|{:8e}|'.format(cpt, err))
+
+        cpt += 1
+    if err > stopThr:
+        warnings.warn("Stabilized Sinkhorn did not converge. " +
+                      "Try a larger entropy `reg` " +
+                      "or a larger absorption threshold `tau`.")
+    if log:
+        log['niter'] = cpt
+        log['logu'] = np.log(u + 1e-16)
+        log['logv'] = np.log(v + 1e-16)
+        return q, log
+    else:
+        return q
+
+
 def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
     r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images.
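Editor's note: the refactor above moves the solver variants behind a single public `barycenter` entry point with a `method` argument, and unknown method names now raise a `ValueError` instead of silently falling back to Sinkhorn-Knopp. A minimal usage sketch follows; the histograms and cost matrix are assumed for illustration and are not taken from the patch or its tests:

    import numpy as np
    import ot

    # two histograms on a 4-bin grid, stored as columns of A, each summing to 1
    A = np.array([[.4, .1], [.3, .2], [.2, .3], [.1, .4]])
    M = ot.utils.dist0(4)   # squared-distance cost on a 1D grid
    M /= M.max()

    bary = ot.bregman.barycenter(A, M, reg=1e-2, method='sinkhorn')
    bary_stable = ot.bregman.barycenter(A, M, reg=1e-2,
                                        method='sinkhorn_stabilized')

    try:
        ot.bregman.barycenter(A, M, reg=1e-2, method='do_not_exists')
    except ValueError as err:
        print(err)          # Unknown method 'do_not_exists'.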
@@ -1101,16 +1283,16 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 Parameters ---------- - A : ndarray, shape (n, w, h) - n distributions (2D images) of size w x h + A : ndarray, shape (n_hists, width, height) + n distributions (2D images) of size width x height reg : float Regularization term >0 - weights : ndarray, shape (n,) + weights : ndarray, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on error (> 0) stabThr : float, optional Stabilization threshold to avoid numerical precision issue verbose : bool, optional @@ -1120,7 +1302,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 Returns ------- - a : ndarray, shape (w, h) + a : ndarray, shape (width, height) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -1214,15 +1396,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Parameters ---------- - a : ndarray, shape (d) + a : ndarray, shape (n_observed) observed distribution - D : ndarray, shape (d, n) + D : ndarray, shape (dim, dim) dictionary matrix - M : ndarray, shape (d, d) + M : ndarray, shape (dim, dim) loss matrix - M0 : ndarray, shape (n, n) + M0 : ndarray, shape (n_observed, n_observed) loss matrix - h0 : ndarray, shape (n,) + h0 : ndarray, shape (dim,) prior on h reg : float Regularization term >0 (Wasserstein data fitting) @@ -1242,7 +1424,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Returns ------- - a : ndarray, shape (d,) + a : ndarray, shape (dim,) Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -1315,22 +1497,22 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI \gamma\geq 0 where : - - :math:`M` is the (ns,nt) metric cost matrix + - :math:`M` is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`a` and :math:`b` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (ns, d) + X_s : ndarray, shape (dim_a, d) samples in the source domain - X_t : ndarray, shape (nt, d) + X_t : ndarray, shape (dim_b, d) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (ns,) + a : ndarray, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (nt,) + b : ndarray, shape (dim_b,) samples weights in the target domain numItermax : int, optional Max number of iterations @@ -1344,7 +1526,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (dim_a, n_b) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1352,11 +1534,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI Examples -------- - >>> n_s = 2 - >>> n_t = 2 + >>> n_a = 2 + >>> n_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) - >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> X_s = np.reshape(np.arange(n_a), (dim_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_b), (dim_b, 1)) >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 
4.99977301e-01]]) @@ -1405,22 +1587,22 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num \gamma\geq 0 where : - - :math:`M` is the (ns,nt) metric cost matrix + - :math:`M` is the (dim_a, n_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`a` and :math:`b` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (ns, d) + X_s : ndarray, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (nt, d) + X_t : ndarray, shape (n_samples_b, d) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (ns,) + a : ndarray, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (nt,) + b : ndarray, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations @@ -1434,7 +1616,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (n_samples_a, n_samples_b) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1442,11 +1624,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Examples -------- - >>> n_s = 2 - >>> n_t = 2 + >>> n_a = 2 + >>> n_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) - >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False) array([4.53978687e-05]) @@ -1513,22 +1695,22 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b\geq 0 where : - - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt)) + - :math:`M` (resp. 
:math:`M_a, M_b`) is the (dim_a, n_b) metric cost matrix (resp (dim_a, ns) and (dim_b, nt)) - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`a` and :math:`b` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (ns, d) + X_s : ndarray, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (nt, d) + X_t : ndarray, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (ns,) + a : ndarray, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (nt,) + b : ndarray, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations @@ -1541,18 +1723,18 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Returns ------- - gamma : ndarray, shape (ns, nt) + gamma : ndarray, shape (n_samples_a, n_samples_b) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters Examples -------- - >>> n_s = 2 - >>> n_t = 4 + >>> n_a = 2 + >>> n_b = 4 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) - >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS array([1.499...]) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 0f0692e..3f71d28 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -9,51 +9,56 @@ Regularized Unbalanced OT from __future__ import division import warnings import numpy as np +from scipy.special import logsumexp + # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the unbalanced entropic regularization optimal transport problem and return the loss + Solve the unbalanced entropic regularization optimal transport problem + and return the OT plan The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) s.t. 
\gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns, nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshol on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, Returns ------- - W : (nt) ndarray or float - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` Examples -------- @@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. 
arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_stabilized_unbalanced: + Unbalanced Stabilized sinkhorn [9][10] + ot.unbalanced.sinkhorn_reg_scaling_unbalanced: + Unbalanced Sinkhorn with epslilon scaling [9][10] """ if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) else: - raise ValueError('Unknown method. Using classic Sinkhorn Knopp') - - return sink() + raise ValueError("Unknown method '%s'." % method) -def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', - numItermax=1000, stopThr=1e-9, verbose=False, +def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', + numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" - Solve the entropic regularization unbalanced optimal transport problem and return the loss + Solve the entropic regularization unbalanced optimal transport problem and + return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) s.t. 
\gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 method : str method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + 'sinkhorn_reg_scaling', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional @@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', Returns ------- - W : (nt) ndarray or float - Optimal transportation matrix for the given parameters + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` log : dict - log dictionary return only if log==True in parameters + log dictionary returned only if `log` is `True` Examples -------- @@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems + (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. 
: + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] """ - - if method.lower() == 'sinkhorn': - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - - elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']: - warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - - def sink(): - return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) - else: - raise ValueError('Unknown method. Using classic Sinkhorn Knopp') - b = np.asarray(b, dtype=np.float64) if len(b.shape) < 2: b = b[:, None] - - return sink() + if method.lower() == 'sinkhorn': + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + + elif method.lower() == 'sinkhorn_stabilized': + return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: + warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') + return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + else: + raise ValueError('Unknown method %s.' % method) -def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b) + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b) s.t. 
\gamma\geq 0 where : - - M is the (ns, nt) metric cost matrix + - M is the (dim_a, dim_b) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights + - a and b are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ @@ -256,16 +286,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, Parameters ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 - alpha : float + reg_m: float Marginal relaxation term > 0 numItermax : int, optional Max number of iterations @@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, Returns ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` Examples -------- @@ -291,16 +326,20 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, >>> a=[.5, .5] >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] - >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) + >>> ot.unbalanced._sinkhorn_knopp_unbalanced(a, b, M, 1., 1.) array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. - .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. 
: + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 See Also -------- @@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - n_a, n_b = M.shape + dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(n_a, dtype=np.float64) / n_a + a = np.ones(dim_a, dtype=np.float64) / dim_a if len(b) == 0: - b = np.ones(n_b, dtype=np.float64) / n_b + b = np.ones(dim_b, dtype=np.float64) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((n_a, 1)) / n_a - v = np.ones((n_b, n_hists)) / n_b - a = a.reshape(n_a, 1) + u = np.ones((dim_a, 1)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) else: - u = np.ones(n_a) / n_a - v = np.ones(n_b) / n_b + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b - # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) - # print(np.min(K)) - fi = alpha / (alpha + reg) + fi = reg_m / (reg_m + reg) cpt = 0 err = 1. @@ -371,8 +408,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) + err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err = 0.5 * (err_u + err_v) if log: log['err'].append(err) if verbose: @@ -383,8 +421,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, cpt += 1 if log: - log['u'] = u - log['v'] = v + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) if n_hists: # return only loss res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) @@ -401,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, return u[:, None] * K * v[None, :] -def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A +def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, + stopThr=1e-6, verbose=False, log=False, + **kwargs): + r""" + Solve the entropic regularization unbalanced optimal transport + problem and return the loss + + The function solves the following optimization problem using log-domain + stabilization as proposed in [10]: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + + s.t. 
+ \gamma\geq 0 + where : + + - M is the (dim_a, dim_b) metric cost matrix + - :math:`\Omega` is the entropic regularization + term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension dim_a + b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + One or multiple unnormalized histograms of dimension dim_b + If many, compute all the OT distances (a, b_i) + M : np.ndarray (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float + Marginal relaxation term > 0 + tau : float + thershold for max value in u or v for log scaling + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + gamma : (dim_a x dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + else: + ot_distance : (n_hists,) ndarray + the OT distance between `a` and each of the histograms `b_i` + log : dict + log dictionary returned only if `log` is `True` + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced._sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.) + array([[0.51122823, 0.18807035], + [0.18807035, 0.51122823]]) + + References + ---------- + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : + Learning with a Wasserstein Loss, Advances in Neural Information + Processing Systems (NIPS) 2015 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = np.ones(dim_a, dtype=np.float64) / dim_a + if len(b) == 0: + b = np.ones(dim_b, dtype=np.float64) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if n_hists: + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b + a = a.reshape(dim_a, 1) + else: + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b + + # print(reg) + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + fi = reg_m / (reg_m + reg) + + cpt = 0 + err = 1. 
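+    # (editor's note) alpha and beta accumulate the log-domain
+    # absorptions: whenever a scaling vector exceeds tau, its log is
+    # folded into these duals and the kernel K is recomputed, which
+    # keeps u and v in a numerically safe range.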
+    alpha = np.zeros(dim_a)
+    beta = np.zeros(dim_b)
+    while (err > stopThr and cpt < numItermax):
+        uprev = u
+        vprev = v
+
+        Kv = K.dot(v)
+        f_alpha = np.exp(- alpha / (reg + reg_m))
+        f_beta = np.exp(- beta / (reg + reg_m))
+
+        if n_hists:
+            f_alpha = f_alpha[:, None]
+            f_beta = f_beta[:, None]
+        u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+        Ktu = K.T.dot(u)
+        v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+        absorbing = False
+        if (u > tau).any() or (v > tau).any():
+            absorbing = True
+            if n_hists:
+                alpha = alpha + reg * np.log(np.max(u, 1))
+                beta = beta + reg * np.log(np.max(v, 1))
+            else:
+                alpha = alpha + reg * np.log(np.max(u))
+                beta = beta + reg * np.log(np.max(v))
+            K = np.exp((alpha[:, None] + beta[None, :] -
+                        M) / reg)
+            v = np.ones_like(v)
+        Kv = K.dot(v)
+
+        if (np.any(Ktu == 0.)
+                or np.any(np.isnan(u)) or np.any(np.isnan(v))
+                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+            # we have reached the machine precision
+            # come back to previous solution and quit loop
+            warnings.warn('Numerical errors at iteration %s' % cpt)
+            u = uprev
+            v = vprev
+            break
+        if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+            # we can speed up the process by checking for the error only all
+            # the 10th iterations
+            err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+                                             1.)
+            if log:
+                log['err'].append(err)
+            if verbose:
+                if cpt % 200 == 0:
+                    print(
+                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+                print('{:5d}|{:8e}|'.format(cpt, err))
+        cpt = cpt + 1
+
+    if err > stopThr:
+        warnings.warn("Stabilized Unbalanced Sinkhorn did not converge. " +
+                      "Try a larger entropy `reg` or a lower mass `reg_m`, " +
+                      "or a larger absorption threshold `tau`.")
+    if n_hists:
+        logu = alpha[:, None] / reg + np.log(u)
+        logv = beta[:, None] / reg + np.log(v)
+    else:
+        logu = alpha / reg + np.log(u)
+        logv = beta / reg + np.log(v)
+    if log:
+        log['logu'] = logu
+        log['logv'] = logv
+    if n_hists:  # return only loss
+        res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+                        logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+        res = np.exp(res)
+        if log:
+            return res, log
+        else:
+            return res
+
+    else:  # return OT matrix
+        ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+        if log:
+            return ot_matrix, log
+        else:
+            return ot_matrix
+
+
+def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+                                      numItermax=1000, stopThr=1e-6,
+                                      verbose=False, log=False):
+    r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
+
+    The function solves the following optimization problem:
+
@@ -412,28 +665,184 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,

     where :

-    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
-    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
-    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
-    - alpha is the marginal relaxation hyperparameter
-    The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+      Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of
+      matrix :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and
+      the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+    The algorithm used for solving the problem is the generalized
+    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_

     Parameters
     ----------
-    A : np.ndarray (d,n)
-        n training distributions a_i of size d
-    M : np.ndarray (d,d)
-        loss matrix for OT
+    A : np.ndarray (dim, n_hists)
+        `n_hists` training distributions a_i of dimension dim
+    M : np.ndarray (dim, dim)
+        ground metric matrix for OT.
     reg : float
         Entropy regularization term > 0
-    alpha : float
+    reg_m : float
         Marginal relaxation term > 0
-    weights : np.ndarray (n,)
-        Weights of each histogram a_i on the simplex (barycentric coordinates)
+    tau : float
+        Stabilization threshold for log domain absorption.
+    weights : np.ndarray (n_hists,) optional
+        Weight of each distribution (barycentric coordinates)
+        If None, uniform weights are used.
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
-        Stop threshold on error (>0)
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    a : (dim,) ndarray
+        Unbalanced Wasserstein barycenter
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+       G. (2015). Iterative Bregman projections for regularized transportation
+       problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+       Scaling algorithms for unbalanced transport problems. arXiv preprint
+       arXiv:1607.05816.
+
+
+    """
+    dim, n_hists = A.shape
+    if weights is None:
+        weights = np.ones(n_hists) / n_hists
+    else:
+        assert(len(weights) == A.shape[1])
+
+    if log:
+        log = {'err': []}
+
+    fi = reg_m / (reg_m + reg)
+
+    u = np.ones((dim, n_hists)) / dim
+    v = np.ones((dim, n_hists)) / dim
+
+    # print(reg)
+    # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+    K = np.empty(M.shape, dtype=M.dtype)
+    np.divide(M, -reg, out=K)
+    np.exp(K, out=K)
+
+    fi = reg_m / (reg_m + reg)
+
+    cpt = 0
+    err = 1.
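+    # (editor's note) q holds the barycenter iterate; each pass rescales
+    # u and v against A and q, with the same tau-triggered absorption
+    # into the duals alpha and beta as in the stabilized unbalanced
+    # Sinkhorn above.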
+    alpha = np.zeros(dim)
+    beta = np.zeros(dim)
+    q = np.ones(dim) / dim
+    while (err > stopThr and cpt < numItermax):
+        qprev = q
+        Kv = K.dot(v)
+        f_alpha = np.exp(- alpha / (reg + reg_m))
+        f_beta = np.exp(- beta / (reg + reg_m))
+        f_alpha = f_alpha[:, None]
+        f_beta = f_beta[:, None]
+        u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+        Ktu = K.T.dot(u)
+        q = (Ktu ** (1 - fi)) * f_beta
+        q = q.dot(weights) ** (1 / (1 - fi))
+        Q = q[:, None]
+        v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+        absorbing = False
+        if (u > tau).any() or (v > tau).any():
+            absorbing = True
+            alpha = alpha + reg * np.log(np.max(u, 1))
+            beta = beta + reg * np.log(np.max(v, 1))
+            K = np.exp((alpha[:, None] + beta[None, :] -
+                        M) / reg)
+            v = np.ones_like(v)
+        Kv = K.dot(v)
+        if (np.any(Ktu == 0.)
+                or np.any(np.isnan(u)) or np.any(np.isnan(v))
+                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+            # we have reached the machine precision
+            # come back to previous solution and quit loop
+            warnings.warn('Numerical errors at iteration %s' % cpt)
+            q = qprev
+            break
+        if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+            # we can speed up the process by checking the error only every
+            # 10th iteration
+            err = abs(q - qprev).max() / max(abs(q).max(),
+                                             abs(qprev).max(), 1.)
+            if log:
+                log['err'].append(err)
+            if verbose:
+                if cpt % 50 == 0:
+                    print(
+                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+                print('{:5d}|{:8e}|'.format(cpt, err))
+
+        cpt += 1
+    if err > stopThr:
+        warnings.warn("Stabilized Unbalanced Sinkhorn did not converge. " +
+                      "Try a larger entropy `reg` or a lower mass `reg_m`. " +
+                      "Or a larger absorption threshold `tau`.")
+    if log:
+        log['niter'] = cpt
+        log['logu'] = np.log(u + 1e-16)
+        log['logv'] = np.log(v + 1e-16)
+        return q, log
+    else:
+        return q
+
+
+def _barycenter_unbalanced(A, M, reg, reg_m, weights=None,
+                           numItermax=1000, stopThr=1e-6,
+                           verbose=False, log=False):
+    r"""Compute the entropic unbalanced Wasserstein barycenter of A.
+
+    The function solves the following optimization problem:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+      Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+      :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and
+      the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+    The algorithm used for solving the problem is the generalized
+    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+    Parameters
+    ----------
+    A : np.ndarray (dim, n_hists)
+        `n_hists` training distributions a_i of dimension dim
+    M : np.ndarray (dim, dim)
+        ground metric matrix for OT.
+    reg : float
+        Entropy regularization term > 0
+    reg_m : float
+        Marginal relaxation term > 0
+    weights : np.ndarray (n_hists,) optional
+        Weight of each distribution (barycentric coordinates)
+        If None, uniform weights are used.
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -442,7 +851,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
 
     Returns
     -------
-    a : (d,) ndarray
+    a : (dim,) ndarray
         Unbalanced Wasserstein barycenter
     log : dict
         log dictionary returned only if log==True in parameters
@@ -451,12 +860,16 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
     References
     ----------
 
-    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
-    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+     (2015). Iterative Bregman projections for regularized transportation
+     problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+     Scaling algorithms for unbalanced transport problems. arXiv preprint
+     arXiv:1607.05816.
 
 
     """
-    p, n_hists = A.shape
+    dim, n_hists = A.shape
     if weights is None:
         weights = np.ones(n_hists) / n_hists
     else:
@@ -467,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
 
     K = np.exp(- M / reg)
 
-    fi = alpha / (alpha + reg)
+    fi = reg_m / (reg_m + reg)
 
-    v = np.ones((p, n_hists)) / p
-    u = np.ones((p, 1)) / p
+    v = np.ones((dim, n_hists)) / dim
+    u = np.ones((dim, 1)) / dim
 
     cpt = 0
     err = 1.
@@ -499,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
         if cpt % 10 == 0:
             # we can speed up the process by checking the error only every
            # 10th iteration
-            err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
-                np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+            err_u = abs(u - uprev).max()
+            err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+            err_v = abs(v - vprev).max()
+            err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+            err = 0.5 * (err_u + err_v)
             if log:
                 log['err'].append(err)
             if verbose:
@@ -512,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
         cpt += 1
     if log:
         log['niter'] = cpt
-        log['u'] = u
-        log['v'] = v
+        log['logu'] = np.log(u + 1e-16)
+        log['logv'] = np.log(v + 1e-16)
         return q, log
     else:
         return q
+
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+                          numItermax=1000, stopThr=1e-6,
+                          verbose=False, log=False, **kwargs):
+    r"""Compute the entropic unbalanced Wasserstein barycenter of A.
+
+    The function solves the following optimization problem:
+
+    ..
 math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+      Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+      :math:`\mathbf{A}`
+    - reg and :math:`\mathbf{M}` are respectively the regularization term and
+      the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+    The algorithm used for solving the problem is the generalized
+    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+    Parameters
+    ----------
+    A : np.ndarray (dim, n_hists)
+        `n_hists` training distributions a_i of dimension dim
+    M : np.ndarray (dim, dim)
+        ground metric matrix for OT.
+    reg : float
+        Entropy regularization term > 0
+    reg_m : float
+        Marginal relaxation term > 0
+    method : str, optional
+        Solver to use, either 'sinkhorn' or 'sinkhorn_stabilized'
+    weights : np.ndarray (n_hists,) optional
+        Weight of each distribution (barycentric coordinates)
+        If None, uniform weights are used.
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    a : (dim,) ndarray
+        Unbalanced Wasserstein barycenter
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+     (2015). Iterative Bregman projections for regularized transportation
+     problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+     Scaling algorithms for unbalanced transport problems. arXiv preprint
+     arXiv:1607.05816.
+
+    """
+
+    if method.lower() == 'sinkhorn':
+        return _barycenter_unbalanced(A, M, reg, reg_m,
+                                      numItermax=numItermax,
+                                      stopThr=stopThr, verbose=verbose,
+                                      log=log, **kwargs)
+
+    elif method.lower() == 'sinkhorn_stabilized':
+        return _barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+                                                 numItermax=numItermax,
+                                                 stopThr=stopThr,
+                                                 verbose=verbose,
+                                                 log=log, **kwargs)
+    elif method.lower() in ['sinkhorn_reg_scaling']:
+        warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+        return _barycenter_unbalanced(A, M, reg, reg_m,
+                                      numItermax=numItermax,
+                                      stopThr=stopThr, verbose=verbose,
+                                      log=log, **kwargs)
+    else:
+        raise ValueError("Unknown method '%s'."
% method) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..e69de29 diff --git a/test/test_bregman.py b/test/test_bregman.py index 83ebba8..f70df10 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,6 +7,7 @@ import numpy as np import ot +import pytest def test_sinkhorn(): @@ -71,13 +72,11 @@ def test_sinkhorn_variants(): Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10) Ges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) - Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10) G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) - np.testing.assert_allclose(G0, Gerr) np.testing.assert_allclose(G0, G_green, atol=1e-5) print(G0, G_green) @@ -96,18 +95,17 @@ def test_sinkhorn_variants_log(): Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) - Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True) G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) - np.testing.assert_allclose(G0, Gerr) np.testing.assert_allclose(G0, G_green, atol=1e-5) print(G0, G_green) -def test_bary(): +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_barycenter(method): n_bins = 100 # nb bins @@ -126,14 +124,42 @@ def test_bary(): weights = np.array([1 - alpha, alpha]) # wasserstein - reg = 1e-3 - bary_wass = ot.bregman.barycenter(A, M, reg, weights) + reg = 1e-2 + bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method) np.testing.assert_allclose(1, np.sum(bary_wass)) ot.bregman.barycenter(A, M, reg, log=True, verbose=True) +def test_barycenter_stabilization(): + + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + # wasserstein + reg = 1e-2 + bar_stable = ot.bregman.barycenter(A, M, reg, weights, + method="sinkhorn_stabilized", + stopThr=1e-8) + bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", + stopThr=1e-8) + np.testing.assert_allclose(bar, bar_stable) + + def test_wasserstein_bary_2d(): size = 100 # size of a square image @@ -279,3 +305,35 @@ def test_stabilized_vs_sinkhorn_multidim(): method="sinkhorn", log=True) np.testing.assert_allclose(G, G2) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) + A = rng.rand(n, 2) + M = ot.dist(x, x) + epsilon = 1. 
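+
+    # 'sinkhorn' and 'sinkhorn_stabilized' must be accepted by sinkhorn,
+    # sinkhorn2 and barycenter alike; greenkhorn and epsilon scaling are
+    # accepted by sinkhorn only, so sinkhorn2 is expected to raise a
+    # ValueError for them, as is any unknown method token.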
+ + for method in IMPLEMENTED_METHODS: + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + ot.bregman.barycenter(A, M, reg=epsilon, method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + ot.bregman.barycenter(A, M, reg=epsilon, method=method) + for method in ONLY_1D_methods: + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + with pytest.raises(ValueError): + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 1395fe1..ca1efba 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -7,9 +7,12 @@ import numpy as np import ot import pytest +from ot.unbalanced import barycenter_unbalanced +from scipy.special import logsumexp -@pytest.mark.parametrize("method", ["sinkhorn"]) + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -23,29 +26,35 @@ def test_unbalanced_convergence(method): M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + reg_m = 1. - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, - stopThr=1e-10, method=method, + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, + method=method, log=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (b / K.T.dot(log["u"])) ** fi - u_final = (a / K.dot(log["v"])) ** fi + # in log-domain + fi = reg_m / (reg_m + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16) + logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_unbalanced_multiple_inputs(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -59,28 +68,59 @@ def test_unbalanced_multiple_inputs(method): M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + reg_m = 1. 
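+
+    # b stacks several histograms column-wise, so the solver broadcasts and
+    # log["logu"] carries one column per histogram; the fixed point
+    # u = (a / Kv)**fi, v = (b / Ktu)**fi is then checked in the log domain
+    # with logsumexp instead of forming K = exp(-M / epsilon) explicitly.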
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - alpha=alpha, - stopThr=1e-10, method=method, + reg_m=reg_m, + method=method, log=True) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (b / K.T.dot(log["u"])) ** fi - - u_final = (a[:, None] / K.dot(log["v"])) ** fi + # in log-domain + fi = reg_m / (reg_m + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) assert len(loss) == b.shape[1] -def test_unbalanced_barycenter(): +def test_stabilized_vs_sinkhorn(): + # test if stable version matches sinkhorn + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + reg_m = 1. + G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, + method="sinkhorn_stabilized", + reg_m=reg_m, + log=True) + G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_barycenter(method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -92,27 +132,56 @@ def test_unbalanced_barycenter(): A = A * np.array([1, 2])[None, :] M = ot.dist(x, x) epsilon = 1. - alpha = 1. - K = np.exp(- M / epsilon) + reg_m = 1. 
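+
+    # the barycenter q must satisfy the same scaling fixed point as the
+    # inputs: in the log domain, logv = fi * (log q - log K^T u) and
+    # logu = fi * (log A - log K v) with fi = reg_m / (reg_m + epsilon),
+    # which is what the assertions below verify via logsumexp.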
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, - stopThr=1e-10, - log=True) + q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method, log=True) # check fixed point equations - fi = alpha / (alpha + epsilon) - v_final = (q[:, None] / K.T.dot(log["u"])) ** fi - u_final = (A / K.dot(log["v"])) ** fi + fi = reg_m / (reg_m + epsilon) + logA = np.log(A + 1e-16) + logq = np.log(q + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logq - logKtu) + u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["u"], atol=1e-05) + u_final, log["logu"], atol=1e-05) np.testing.assert_allclose( - v_final, log["v"], atol=1e-05) + v_final, log["logv"], atol=1e-05) + + +def test_barycenter_stabilized_vs_sinkhorn(): + # test generalized sinkhorn for unbalanced OT barycenter + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + A = rng.rand(n, 2) + + # make dists unbalanced + A = A * np.array([1, 4])[None, :] + M = ot.dist(x, x) + epsilon = 0.5 + reg_m = 10 + + qstable, log = barycenter_unbalanced(A, M, reg=epsilon, + reg_m=reg_m, log=True, + tau=100, + method="sinkhorn_stabilized", + ) + q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method="sinkhorn", + log=True) + + np.testing.assert_allclose( + q, qstable, atol=1e-05) def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn'] - TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', - 'sinkhorn_epsilon_scaling'] + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 @@ -123,24 +192,30 @@ def test_implemented_methods(): # make dists unbalanced b = ot.utils.unif(n) * 1.5 - + A = rng.rand(n, 2) M = ot.dist(x, x) epsilon = 1. - alpha = 1. + reg_m = 1. 
for method in IMPLEMENTED_METHODS: - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method) + barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method) with pytest.warns(UserWarning, match='not implemented'): for method in set(TO_BE_IMPLEMENTED_METHODS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method) + barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method) with pytest.raises(ValueError): for method in set(NOT_VALID_TOKENS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, method=method) + barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method) -- cgit v1.2.3 From 2a32e2ea64d0d5096953a9b8259b0507fa58dca5 Mon Sep 17 00:00:00 2001 From: Kilian Date: Wed, 13 Nov 2019 13:55:24 +0100 Subject: fix log bug in gromov_wasserstein2 --- ot/gromov.py | 156 ++++++++++++++++++++++++++-------------------------- test/test_gromov.py | 4 ++ 2 files changed, 81 insertions(+), 79 deletions(-) (limited to 'test') diff --git a/ot/gromov.py b/ot/gromov.py index 699ae4c..9869341 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - - H : entropy Parameters ---------- @@ -343,6 +342,83 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): + """ + Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + + The function solves the following optimization problem: + + .. math:: + GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + Where : + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space. + q : ndarray, shape (nt,) + Distribution in the target space. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. 
+ + Returns + ------- + gw_dist : float + Gromov-Wasserstein distance + log : dict + convergence information and Coupling marix + + References + ---------- + .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the + metric approach to object matching. Foundations of computational + mathematics 11.4 (2011): 417-487. + + """ + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + + def f(G): + return gwloss(constC, hC1, hC2, G) + + def df(G): + return gwggrad(constC, hC1, hC2, G) + res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) + log_gw['T'] = res + if log: + return log_gw['gw_dist'], log_gw + else: + return log_gw['gw_dist'] + + def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): """ Computes the FGW transport between two graphs see [24] @@ -506,84 +582,6 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 return log['fgw_dist'] -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ - Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) - - The function solves the following optimization problem: - - .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric cost matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space. - q : ndarray, shape (nt,) - Distribution in the target space. - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - log : dict - convergence information and Coupling marix - - References - ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. 
- - """ - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - G0 = p[:, None] * q[None, :] - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - log['T'] = res - if log: - return log['gw_dist'], log - else: - return log['gw_dist'] - - def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): """ diff --git a/test/test_gromov.py b/test/test_gromov.py index 70fa83f..43da9fc 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -44,10 +44,14 @@ def test_gromov(): gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + G = log['T'] np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False + # check constratints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov -- cgit v1.2.3 From 0280a3441b09c781035cda3b74213ec92026ff9e Mon Sep 17 00:00:00 2001 From: Kilian Date: Fri, 15 Nov 2019 16:10:37 +0100 Subject: fix bug numItermax emd in cg --- ot/optim.py | 6 ++++-- test/test_optim.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/ot/optim.py b/ot/optim.py index 0abd9e9..4012e0d 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -134,7 +134,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, return alpha, fc, f_val -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, +def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -172,6 +172,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, initial guess (default is indep joint density) numItermax : int, optional Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd stopThr : float, optional Stop threshol on the relative variation (>0) stopThr2 : float, optional @@ -238,7 +240,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Mi += Mi.min() # solve linear program - Gc = emd(a, b, Mi) + Gc = emd(a, b, Mi, numItermax=numItermaxEmd) deltaG = Gc - G diff --git a/test/test_optim.py b/test/test_optim.py index ae31e1f..aade36e 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -37,6 +37,39 @@ def test_conditional_gradient(): np.testing.assert_allclose(b, G.sum(0)) +def test_conditional_gradient2(): + n = 4000 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([4, 4]) + cov_t = np.array([[1, -.8], [-.8, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) + xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + + a, b = np.ones((n,)) / n, np.ones((n,)) / n + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + def f(G): + return 0.5 * np.sum(G**2) + + def df(G): + return G + + reg = 1e-1 + + G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000, + verbose=True, log=True) + + np.testing.assert_allclose(a, G.sum(1)) + np.testing.assert_allclose(b, G.sum(0)) + + def test_generalized_conditional_gradient(): n_bins = 100 # nb bins -- cgit v1.2.3 From 57321bd0172c97b77dfc8b14972c18d063b6dda8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary 
Date: Mon, 2 Dec 2019 11:13:07 +0100
Subject: add awesome sparse solver

---
 ot/lp/EMD_wrapper.cpp | 65 ++++++++++++++++++++++++++++++++++++---------------
 ot/lp/emd_wrap.pyx    |  2 +-
 test/test_ot.py       | 20 ++++++++++++++++
 3 files changed, 67 insertions(+), 20 deletions(-)

(limited to 'test')

diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 3ca7319..2aa44c1 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -111,23 +111,19 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
                     long *iG, long *jG, double *G,
                     double* alpha, double* beta, double *cost, int maxIter) {
     // beware M and C are stored in row major C style!!!
-    int n, m, i, cur;
+
+    // Get the number of non zero coordinates for r and c and vectors
+    int n, m, i, cur;
     typedef FullBipartiteDigraph Digraph;
     DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
-    std::vector<int> indI(n), indJ(m);
-    std::vector<double> weights1(n), weights2(m);
-    Digraph di(n, m);
-    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(di, true, n+m, n*m, maxIter);
-
-    // Get the number of non zero coordinates for r and c and vectors
+    // Get the number of non zero coordinates for r and c
     n=0;
     for (int i=0; i<n1; i++) {
         double val=*(X+i);
         if (val>0) {
-            weights1[ n ] = val;
-            indI[n++]=i;
+            n++;
         }else if(val<0){
             return INFEASIBLE;
         }
     }
     m=0;
     for (int i=0; i<n2; i++) {
         double val=*(Y+i);
         if (val>0) {
-            weights2[ m ] = -val;
-            indJ[m++]=i;
+            m++;
         }else if(val<0){
             return INFEASIBLE;
         }
     }
+    // Define the graph
+
+    std::vector<int> indI(n), indJ(m);
+    std::vector<double> weights1(n), weights2(m);
+    Digraph di(n, m);
+    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(di, true, n+m, n*m, maxIter);
+
+    // Set supply and demand, don't account for 0 values (faster)
+
+    cur=0;
+    for (int i=0; i<n1; i++) {
+        double val=*(X+i);
+        if (val>0) {
+            weights1[ cur ] = val;
+            indI[cur++]=i;
+        }
+    }
+
+    // Demand is actually negative supply...
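+    // (the network simplex represents demand as negative supply, hence the
+    // sign flip on the target weights in the loop that follows)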
+ + cur=0; + for (int i=0; i0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + // Define the graph net.supplyMap(&weights1[0], n, &weights2[0], m); @@ -166,14 +190,17 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, int i = di.source(a); int j = di.target(a); double flow = net.flow(a); - *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); - - *(G+cur) = flow; - *(iG+cur) = indI[i]; - *(jG+cur) = indJ[j]; - *(alpha + indI[i]) = -net.potential(i); - *(beta + indJ[j-n]) = net.potential(j); - cur++; + if (flow>0) + { + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + + *(G+cur) = flow; + *(iG+cur) = indI[i]; + *(jG+cur) = indJ[j-n]; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + cur++; + } } } diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 345cb66..f183995 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -111,7 +111,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod jG=np.zeros(nmax,dtype=np.int) - result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, G.data, alpha.data, beta.data, &cost, max_iter) + result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, Gv.data, alpha.data, beta.data, &cost, max_iter) return Gv, iG, jG, cost, alpha, beta, result_code diff --git a/test/test_ot.py b/test/test_ot.py index dacae0a..4d59e12 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -118,6 +118,26 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_sparse(): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x2 = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x2) + + G = ot.emd([], [], M) + + Gs = ot.emd([], [], M, sparse=True) + + # check G is the same + np.testing.assert_allclose(G, Gs.todense()) + # check constraints + + def test_emd2_multi(): n = 500 # nb bins -- cgit v1.2.3 From a6a654de5e78dd388a793fbd26f60045b05d519c Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 11:31:32 +0100 Subject: proper documentation and parameter --- ot/lp/EMD.h | 2 +- ot/lp/EMD_wrapper.cpp | 3 ++- ot/lp/__init__.py | 16 ++++++++++++++-- ot/lp/emd_wrap.pyx | 10 ++++++---- test/test_ot.py | 2 +- 5 files changed, 24 insertions(+), 9 deletions(-) (limited to 'test') diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index bc513d2..9896091 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -33,7 +33,7 @@ enum ProblemType { int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, - long *iG, long *jG, double *G, + long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 2aa44c1..9be2cdc 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -108,7 +108,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D, - long *iG, long *jG, double *G, + long *iG, long *jG, double *G, long * nG, double* alpha, double* beta, double *cost, int maxIter) { // beware M and C anre strored in row major C style!!! 
@@ -202,6 +202,7 @@ int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
             cur++;
         }
     }
+    *nG=cur; // number of non zero values stored (used for numpy slicing)

 }

diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 4fec7d9..d476071 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -27,7 +27,7 @@ __all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
          'emd_1d', 'emd2_1d', 'wasserstein_1d']


-def emd(a, b, M, numItermax=100000, log=False, sparse=False):
+def emd(a, b, M, numItermax=100000, log=False, dense=True):
     r"""Solves the Earth Movers distance problem and returns the OT matrix


@@ -62,6 +62,10 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False):
     log: bool, optional (default=False)
         If True, returns a dictionary containing the cost and dual
         variables. Otherwise returns only the optimal transportation matrix.
+    dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+        Otherwise returns a sparse representation using scipy's `coo_matrix`
+        format.

     Returns
     -------
@@ -103,6 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False):
     b = np.asarray(b, dtype=np.float64)
     M = np.asarray(M, dtype=np.float64)

+    sparse = not dense
+
     # if empty array given then use uniform distributions
     if len(a) == 0:
         a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -128,7 +134,7 @@ def emd(a, b, M, numItermax=100000, log=False, sparse=False):


 def emd2(a, b, M, processes=multiprocessing.cpu_count(),
-         numItermax=100000, log=False, sparse=False, return_matrix=False):
+         numItermax=100000, log=False, dense=True, return_matrix=False):
     r"""Solves the Earth Movers distance problem and returns the loss

@@ -166,6 +172,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
         variables. Otherwise returns only the optimal transportation cost.
     return_matrix: boolean, optional (default=False)
         If True, returns the optimal transportation matrix in the log.
+    dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+        Otherwise returns a sparse representation using scipy's `coo_matrix`
+        format.

     Returns
     -------
@@ -207,6 +217,8 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
     b = np.asarray(b, dtype=np.float64)
     M = np.asarray(M, dtype=np.float64)

+    sparse = not dense
+
     # problem with pickling Forks
     if sys.platform.endswith('win32'):
         processes=1

diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index f183995..4b6cdce 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -21,7 +21,7 @@ import warnings
 cdef extern from "EMD.h":
     int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
     int EMD_wrap_return_sparse(int n1, int n2, double *X, double *Y, double *D,
-                               long *iG, long *jG, double *G,
+                               long *iG, long *jG, double *G, long * nG,
                                double* alpha, double* beta, double *cost, int maxIter)

 cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED

@@ -75,7 +75,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
     max_iter : int
         The maximum number of iterations before stopping the optimization
         algorithm if it has not converged.
- + sparse : bool + Returning a sparse transport matrix if set to True Returns ------- @@ -87,6 +88,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef int n2= M.shape[1] cdef int nmax=n1+n2-1 cdef int result_code = 0 + cdef int nG=0 cdef double cost=0 cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) @@ -111,10 +113,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod jG=np.zeros(nmax,dtype=np.int) - result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, Gv.data, alpha.data, beta.data, &cost, max_iter) + result_code = EMD_wrap_return_sparse(n1, n2, a.data, b.data, M.data, iG.data, jG.data, Gv.data, &nG, alpha.data, beta.data, &cost, max_iter) - return Gv, iG, jG, cost, alpha, beta, result_code + return Gv[:nG], iG[:nG], jG[:nG], cost, alpha, beta, result_code else: diff --git a/test/test_ot.py b/test/test_ot.py index 4d59e12..7b44fd1 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -131,7 +131,7 @@ def test_emd_sparse(): G = ot.emd([], [], M) - Gs = ot.emd([], [], M, sparse=True) + Gs = ot.emd([], [], M, dense=False) # check G is the same np.testing.assert_allclose(G, Gs.todense()) -- cgit v1.2.3 From 127adbaf4eef7a6dffbdcd4f930fc6301587f861 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 11:41:13 +0100 Subject: remove useless variable --- test/test_ot.py | 1 - 1 file changed, 1 deletion(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 7b44fd1..8602022 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -125,7 +125,6 @@ def test_emd_sparse(): x = rng.randn(n, 2) x2 = rng.randn(n, 2) - u = ot.utils.unif(n) M = ot.dist(x, x2) -- cgit v1.2.3 From 84384dd9e5dc78ed5cc867a53bd1de31c05d77fc Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:34:05 +0100 Subject: add test emd2 --- test/test_ot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 8602022..507d188 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -132,9 +132,12 @@ def test_emd_sparse(): Gs = ot.emd([], [], M, dense=False) + ws = ot.emd2([], [], M, dense=False) + # check G is the same np.testing.assert_allclose(G, Gs.todense()) - # check constraints + # check value + np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6) def test_emd2_multi(): -- cgit v1.2.3 From 7371b2f4f931db8f67ec2967253be8d95ff9fe80 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:34:55 +0100 Subject: add test emd2 --- test/test_ot.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 507d188..48ea87f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -171,7 +171,12 @@ def test_emd2_multi(): emdn = ot.emd2(a, b, M) ot.toc('multi proc : {} s') + ot.tic() + emdn2 = ot.emd2(a, b, M, dense = False) + ot.toc('multi proc : {} s') + np.testing.assert_allclose(emd1, emdn) + np.testing.assert_allclose(emd1, emdn2) # emd loss multipro proc with log ot.tic() -- cgit v1.2.3 From dfaba55affcca606e8e041bdbd0fc5a7735c2b07 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:36:08 +0100 Subject: add test emd2 multi --- test/test_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 48ea87f..470fd0f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -176,7 +176,7 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') 
np.testing.assert_allclose(emd1, emdn) - np.testing.assert_allclose(emd1, emdn2) + np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) # emd loss multipro proc with log ot.tic() -- cgit v1.2.3 From c439e3efb920086154c741b41f65d99165e875d8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 2 Dec 2019 13:57:13 +0100 Subject: pep8 --- test/test_ot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 470fd0f..fbacd8b 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -172,8 +172,8 @@ def test_emd2_multi(): ot.toc('multi proc : {} s') ot.tic() - emdn2 = ot.emd2(a, b, M, dense = False) - ot.toc('multi proc : {} s') + emdn2 = ot.emd2(a, b, M, dense=False) + ot.toc('multi proc : {} s') np.testing.assert_allclose(emd1, emdn) np.testing.assert_allclose(emd1, emdn2, rtol=1e-6) -- cgit v1.2.3 From 92233f79e098f1930248d815e66c0a929508af59 Mon Sep 17 00:00:00 2001 From: Kilian Date: Mon, 9 Dec 2019 15:56:48 +0100 Subject: add assert for emd dimension mismatch --- ot/lp/__init__.py | 6 ++++++ test/test_ot.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) (limited to 'test') diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 0c92810..f77c3d7 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -109,6 +109,9 @@ def emd(a, b, M, numItermax=100000, log=False): if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) result_code_string = check_result(result_code) if log: @@ -212,6 +215,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if len(b) == 0: b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + assert (a.shape[0] == M.shape[0] or b.shape[0] == M.shape[1]), \ + "Dimension mismatch, check dimensions of M with a and b" + if log or return_matrix: def f(b): G, cost, u, v, resultCode = emd_c(a, b, M, numItermax) diff --git a/test/test_ot.py b/test/test_ot.py index dacae0a..1343604 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -14,6 +14,22 @@ from ot.datasets import make_1D_gauss as gauss import pytest +def test_emd_dimension_mismatch(): + # test emd and emd2 for simple identity + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples + 1) + + M = ot.dist(x, x) + + np.testing.assert_raises(AssertionError, emd, a, a, M) + + np.testing.assert_raises(AssertionError, emd2, a, a, M) + + def test_emd_emd2(): # test emd and emd2 for simple identity n = 100 -- cgit v1.2.3 From 428b44e15591071cfcd69af365d878cfd876f9d3 Mon Sep 17 00:00:00 2001 From: Kilian Date: Mon, 9 Dec 2019 16:35:49 +0100 Subject: calling ot.emd test --- test/test_ot.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 1343604..25cdfd4 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -14,8 +14,9 @@ from ot.datasets import make_1D_gauss as gauss import pytest + def test_emd_dimension_mismatch(): - # test emd and emd2 for simple identity + # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 rng = np.random.RandomState(0) @@ -25,9 +26,9 @@ def test_emd_dimension_mismatch(): M = ot.dist(x, x) - np.testing.assert_raises(AssertionError, emd, a, a, M) + np.testing.assert_raises(AssertionError, ot.emd, a, a, M) - 
np.testing.assert_raises(AssertionError, emd2, a, a, M) + np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) def test_emd_emd2(): -- cgit v1.2.3 From 92dbe259032d340a259209e477e9aac74897689e Mon Sep 17 00:00:00 2001 From: Kilian Date: Mon, 9 Dec 2019 16:43:54 +0100 Subject: pep8 --- test/test_ot.py | 1 - 1 file changed, 1 deletion(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index 25cdfd4..42a3d0a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -14,7 +14,6 @@ from ot.datasets import make_1D_gauss as gauss import pytest - def test_emd_dimension_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 -- cgit v1.2.3 From d97f81dd731c4b1132939500076fd48c89f19d1f Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 18 Dec 2019 10:17:31 +0100 Subject: update test --- test/test_ot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_ot.py b/test/test_ot.py index fbacd8b..3dd544c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -128,7 +128,7 @@ def test_emd_sparse(): M = ot.dist(x, x2) - G = ot.emd([], [], M) + G = ot.emd([], [], M, dense=True) Gs = ot.emd([], [], M, dense=False) -- cgit v1.2.3 From 365adbccc73f7fea28811b16cbbbdbb77761e55c Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Fri, 10 Jan 2020 13:01:42 +0100 Subject: add simple test for screenkhorn --- test/test_bregman.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index f70df10..eb74a9f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -337,3 +337,14 @@ def test_implemented_methods(): ot.bregman.sinkhorn(a, b, M, epsilon, method=method) with pytest.raises(ValueError): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + +def test_screenkhorn(): + # test screenkhorn + rng = np.random.RandomState(0) + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + x = rng.randn(n, 2) + M = ot.dist(x, x) + G_screen = ot.bregman.screenkhorn(a, b, M, 1e-1) \ No newline at end of file -- cgit v1.2.3 From 18242437e73aba9cf131fafc1571e376b57f25f6 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Mon, 13 Jan 2020 09:50:49 +0100 Subject: fix simple test of screenkhorn in test/ --- test/test_bregman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index eb74a9f..bc8f6ae 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -347,4 +347,4 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) - G_screen = ot.bregman.screenkhorn(a, b, M, 1e-1) \ No newline at end of file + G_screen = ot.bregman.screenkhorn(a, b, M, 1e-2, uniform=True, verbose=True) \ No newline at end of file -- cgit v1.2.3 From 4918d2c619aaa654c524c9c5dc7f4dc82b838f82 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Thu, 16 Jan 2020 16:44:40 +0100 Subject: update readme --- README.md | 2 +- test/test_bregman.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/README.md b/README.md index 987adf1..c115776 100644 --- a/README.md +++ b/README.md @@ -256,4 +256,4 @@ You can also post bug reports and feature requests in Github issues. Make sure t [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). -[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). 
[Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NIPS). +[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS). diff --git a/test/test_bregman.py b/test/test_bregman.py index bc8f6ae..52e9fb2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -338,6 +338,7 @@ def test_implemented_methods(): with pytest.raises(ValueError): ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + def test_screenkhorn(): # test screenkhorn rng = np.random.RandomState(0) @@ -347,4 +348,4 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) - G_screen = ot.bregman.screenkhorn(a, b, M, 1e-2, uniform=True, verbose=True) \ No newline at end of file + G_screen = ot.bregman.screenkhorn(a, b, M, 1e-2, uniform=True, verbose=True) -- cgit v1.2.3 From 936b5e1eb965e1d8c71b7b26cfa5238face1aaa3 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. Alaya" Date: Thu, 16 Jan 2020 17:13:01 +0100 Subject: update --- .idea/POT.iml | 11 +++++++++++ test/test_bregman.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 .idea/POT.iml (limited to 'test') diff --git a/.idea/POT.iml b/.idea/POT.iml new file mode 100644 index 0000000..6711606 --- /dev/null +++ b/.idea/POT.iml @@ -0,0 +1,11 @@ + + + + + + + + + + \ No newline at end of file diff --git a/test/test_bregman.py b/test/test_bregman.py index 52e9fb2..bcec095 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -348,4 +348,4 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) - G_screen = ot.bregman.screenkhorn(a, b, M, 1e-2, uniform=True, verbose=True) + G_screen = ot.bregman.screenkhorn(a, b, M, 1e-2, uniform=True, verbose=True) \ No newline at end of file -- cgit v1.2.3 From 3be0c215143e16c59ddd3be902416e91c3292937 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. 
Alaya"
Date: Sat, 18 Jan 2020 07:15:09 +0100
Subject: clean

---
 examples/plot_screenkhorn_1D.py | 2 +-
 test/test_bregman.py            | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

(limited to 'test')

diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py
index 103d54c..7c0de82 100644
--- a/examples/plot_screenkhorn_1D.py
+++ b/examples/plot_screenkhorn_1D.py
@@ -59,7 +59,7 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
 # -----------------------
 # Screenkhorn

-lambd = 1e-3  # entropy parameter
+lambd = 1e-03  # entropy parameter
 ns_budget = 30  # budget number of points to be kept in the source distribution
 nt_budget = 30  # budget number of points to be kept in the target distribution

diff --git a/test/test_bregman.py b/test/test_bregman.py
index bcec095..2398d45 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -348,4 +348,7 @@ def test_screenkhorn():

     x = rng.randn(n, 2)
     M = ot.dist(x, x)
-    G_screen = ot.bregman.screenkhorn(a, b, M, 1e-2, uniform=True, verbose=True)
\ No newline at end of file
+    G_sink = ot.sinkhorn(a, b, M, 1e-03)
+    G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+    np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+    np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
\ No newline at end of file
-- 
cgit v1.2.3


From b3fb1ef40a482f0989686b79373060d764b62d38 Mon Sep 17 00:00:00 2001
From: "Mokhtar Z. Alaya" 
Date: Sat, 18 Jan 2020 07:45:34 +0100
Subject: clean

---
 ot/bregman.py        | 3 ++-
 test/test_bregman.py | 8 ++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

(limited to 'test')

diff --git a/ot/bregman.py b/ot/bregman.py
index aff9f8c..c304b5d 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2117,10 +2117,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
         log['v'] = vsc_full
         log['Isel'] = Isel
         log['Jsel'] = Jsel
+
     gamma = usc_full[:, None] * K * vsc_full[None, :]
     gamma = gamma / gamma.sum()

     if log:
         return gamma, log
     else:
-        return gamma
\ No newline at end of file
+        return gamma

diff --git a/test/test_bregman.py b/test/test_bregman.py
index 2398d45..e376715 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -348,7 +348,7 @@ def test_screenkhorn():

     x = rng.randn(n, 2)
     M = ot.dist(x, x)
-    G_sink = ot.sinkhorn(a, b, M, 1e-03)
-    G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
-    np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
-    np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
+    G_s = ot.sinkhorn(a, b, M, 1e-03)
+    G_sc = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+    np.testing.assert_allclose(G_s.sum(0), G_sc.sum(0), atol=1e-02)
+    np.testing.assert_allclose(G_s.sum(1), G_sc.sum(1), atol=1e-02)
\ No newline at end of file
-- 
cgit v1.2.3


From 7f7b1c547b54b394db975f4ff9d0287904a7b820 Mon Sep 17 00:00:00 2001
From: "Mokhtar Z.
Alaya" Date: Sat, 18 Jan 2020 09:04:48 +0100 Subject: make autopep --- test/test_bregman.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index e376715..fd0679b 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -106,7 +106,6 @@ def test_sinkhorn_variants_log(): @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) def test_barycenter(method): - n_bins = 100 # nb bins # Gaussian distributions @@ -133,7 +132,6 @@ def test_barycenter(method): def test_barycenter_stabilization(): - n_bins = 100 # nb bins # Gaussian distributions @@ -161,7 +159,6 @@ def test_barycenter_stabilization(): def test_wasserstein_bary_2d(): - size = 100 # size of a square image a1 = np.random.randn(size, size) a1 += a1.min() @@ -185,7 +182,6 @@ def test_wasserstein_bary_2d(): def test_unmix(): - n_bins = 50 # nb bins # Gaussian distributions @@ -207,7 +203,7 @@ def test_unmix(): # wasserstein reg = 1e-3 - um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,) + um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, ) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) @@ -256,7 +252,7 @@ def test_empirical_sinkhorn(): def test_empirical_sinkhorn_divergence(): - #Test sinkhorn divergence + # Test sinkhorn divergence n = 10 a = ot.unif(n) b = ot.unif(n) @@ -348,7 +344,10 @@ def test_screenkhorn(): x = rng.randn(n, 2) M = ot.dist(x, x) - G_s = ot.sinkhorn(a, b, M, 1e-03) - G_sc = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True) - np.testing.assert_allclose(G_s.sum(0), G_sc.sum(0), atol=1e-02) - np.testing.assert_allclose(G_s.sum(1), G_sc.sum(1), atol=1e-02) \ No newline at end of file + # sinkhorn + G_sink = ot.sinkhorn(a, b, M, 1e-03) + # screenkhorn + G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True) + # check marginals + np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) + np.testing.assert_allclose(G_s.sum(1), G_screen.sum(1), atol=1e-02) -- cgit v1.2.3 From a1747a10e80751eacca4273af61083a853fb9dd4 Mon Sep 17 00:00:00 2001 From: "Mokhtar Z. 
Alaya"
Date: Sat, 18 Jan 2020 09:12:55 +0100
Subject: make autopep

---
 test/test_bregman.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'test')

diff --git a/test/test_bregman.py b/test/test_bregman.py
index fd0679b..f54ba9f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -350,4 +350,4 @@ def test_screenkhorn():
     G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
     # check marginals
     np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
-    np.testing.assert_allclose(G_s.sum(1), G_screen.sum(1), atol=1e-02)
+    np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
-- 
cgit v1.2.3


From 3844639d3dd4e0dd360ebef34dd657d26664039e Mon Sep 17 00:00:00 2001
From: Rémi Flamary 
Date: Mon, 27 Jan 2020 09:35:19 +0100
Subject: add test for constraint violation of duals

---
 test/test_ot.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'test')

diff --git a/test/test_ot.py b/test/test_ot.py
index 18b6294..c756e51 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -337,7 +337,10 @@ def test_dual_variables():
     # Check that both cost computations are equivalent
     np.testing.assert_almost_equal(cost1, log['cost'])
     check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
-
+
+    viol=log['u'][:,None]+log['v'][None,:]-M
+
+    assert viol.max()<1e-8

 def check_duality_gap(a, b, M, G, u, v, cost):
     cost_dual = np.vdot(a, u) + np.vdot(b, v)
-- 
cgit v1.2.3


From 30fc233f7f62d571a562971a945d68c3782f0780 Mon Sep 17 00:00:00 2001
From: Rémi Flamary 
Date: Mon, 27 Jan 2020 09:37:42 +0100
Subject: correct pep8

---
 test/test_ot.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

(limited to 'test')

diff --git a/test/test_ot.py b/test/test_ot.py
index c756e51..245a107 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -337,10 +337,11 @@ def test_dual_variables():
     # Check that both cost computations are equivalent
     np.testing.assert_almost_equal(cost1, log['cost'])
     check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
-
-    viol=log['u'][:,None]+log['v'][None,:]-M
-
-    assert viol.max()<1e-8
+
+    viol = log['u'][:, None] + log['v'][None, :] - M
+
+    assert viol.max() < 1e-8
+

 def check_duality_gap(a, b, M, G, u, v, cost):
     cost_dual = np.vdot(a, u) + np.vdot(b, v)
-- 
cgit v1.2.3


From f65073faa73b36280a19ff8b9c383e66f8bdbd2b Mon Sep 17 00:00:00 2001
From: Rémi Flamary 
Date: Thu, 30 Jan 2020 08:04:36 +0100
Subject: complete documentation

---
 ot/lp/__init__.py  | 30 +++++++++++++++++++-----------
 ot/lp/emd_wrap.pyx |  6 ++++++
 test/test_ot.py    |  4 ++--
 3 files changed, 27 insertions(+), 13 deletions(-)

(limited to 'test')

diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index aa3166f..cdd505d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -28,10 +28,10 @@ __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',


 def center_ot_dual(alpha0, beta0, a=None, b=None):
-    r"""Center dual OT potentials wrt theirs weights
+    r"""Center dual OT potentials w.r.t. their weights

     The main idea of this function is to find unique dual potentials
-    that ensure some kind of centering/fairness. It will help have
+    that ensure some kind of centering/fairness. The goal is to find dual
+    potentials that lead to the same final objective value for both source
+    and targets (see below for more details). It also helps with stability
+    when the OT solver is called several times with small changes.

     Basically we add another constraint to the potential that will not
Basically we add another constraint to the potential that will not @@ -91,7 +91,15 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): def estimate_dual_null_weights(alpha0, beta0, a, b, M): r"""Estimate feasible values for 0-weighted dual potentials - The feasible values are computed efficiently bjt rather coarsely. + The feasible values are computed efficiently but rather coarsely. + + .. warning:: + This function is necessary because the C++ solver in emd_c + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + First we compute the constraints violations: .. math:: @@ -113,11 +121,11 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0 - In the end the dual potential are centred using function + In the end the dual potentials are centered using function :ref:`center_ot_dual`. Note that all those updates do not change the objective value of the - solution but provide dual potential that do not violate the constraints. + solution but provide dual potentials that do not violate the constraints. Parameters ---------- @@ -130,9 +138,9 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): beta0 : (nt,) numpy.ndarray, float64 Target dual potential a : (ns,) numpy.ndarray, float64 - Source histogram (uniform weight if empty list) + Source distribution (uniform weights if empty list) b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) + Target distribution (uniform weights if empty list) M : (ns,nt) numpy.ndarray, float64 Loss matrix (c-order array with type float64) @@ -150,11 +158,11 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): bsel = b != 0 # compute dual constraints violation - Viol = alpha0[:, None] + beta0[None, :] - M + constraint_violation = alpha0[:, None] + beta0[None, :] - M - # Compute worst violation per line and columns - aviol = np.max(Viol, 1) - bviol = np.max(Viol, 0) + # Compute largest violation per line and columns + aviol = np.max(constraint_violation, 1) + bviol = np.max(constraint_violation, 0) # update corrects violation of alpha_up = -1 * ~asel * np.maximum(aviol, 0) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index a4987f4..d345fd4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -66,6 +66,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod .. warning:: Note that the M matrix needs to be a C-order :py.cls:`numpy.array` + .. warning:: + The C++ solver discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + Parameters ---------- a : (ns,) numpy.ndarray, float64 diff --git a/test/test_ot.py b/test/test_ot.py index 245a107..47df946 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -338,9 +338,9 @@ def test_dual_variables(): np.testing.assert_almost_equal(cost1, log['cost']) check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) - viol = log['u'][:, None] + log['v'][None, :] - M + constraint_violation = log['u'][:, None] + log['v'][None, :] - M - assert viol.max() < 1e-8 + assert constraint_violation.max() < 1e-8 def check_duality_gap(a, b, M, G, u, v, cost): -- cgit v1.2.3