From 767171593f2a98a26b9a39bf110a45085e3b982e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Thu, 24 Mar 2022 10:53:47 +0100 Subject: [MRG] Domain adaptation and unbalanced solvers with backend support (#343) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: Rémi Flamary --- test/test_unbalanced.py | 157 +++++++++++++++++++++++++++++------------------- 1 file changed, 96 insertions(+), 61 deletions(-) (limited to 'test/test_unbalanced.py') diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index e8349d1..db59504 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -9,11 +9,9 @@ import ot import pytest from ot.unbalanced import barycenter_unbalanced -from scipy.special import logsumexp - @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(method): +def test_unbalanced_convergence(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -28,36 +26,51 @@ def test_unbalanced_convergence(method): epsilon = 1. reg_m = 1. 
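The `nx` argument threaded through these tests is POT's backend fixture. A minimal sketch of the conversion pattern the hunks below rely on, assuming the fixture behaves like `ot.backend.get_backend` applied to the converted arrays (illustrative only, not part of the patch):

    import numpy as np
    import ot
    from ot.backend import get_backend

    rng = np.random.RandomState(42)
    x = rng.randn(100, 2)
    a = ot.utils.unif(100)            # uniform histogram
    M = ot.dist(x, x)                 # squared Euclidean cost matrix

    nx = get_backend(M, a)            # NumpyBackend here; torch/jax/... for such inputs
    a_nx, M_nx = nx.from_numpy(a, M)  # move the inputs onto the backend
    loss = nx.to_numpy(nx.sum(M_nx))  # convert back before np.testing assertions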
+ a, b, M = nx.from_numpy(a, b, M) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method=method, - verbose=True) + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method=method, verbose=True + )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16) + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss - np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) + + # check in case no histogram is provided + M_np = nx.to_numpy(M) + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + G_np = ot.unbalanced.sinkhorn_unbalanced( + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + np.testing.assert_allclose(G_np, nx.to_numpy(G)) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_multiple_inputs(method): +def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -72,6 +85,8 @@ def test_unbalanced_multiple_inputs(method): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, @@ -80,23 +95,24 @@ def test_unbalanced_multiple_inputs(method): # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): +def test_stabilized_vs_sinkhorn(nx): # test if stable version matches sinkhorn n = 100 @@ -112,19 +128,27 @@ def test_stabilized_vs_sinkhorn(): M /= np.median(M) epsilon = 0.1 reg_m = 1. 
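The fixed-point assertions in test_unbalanced_convergence above encode the log-domain scaling updates of unbalanced Sinkhorn. A compact NumPy sketch of the iteration they verify, with the same exponent fi = reg_m / (reg_m + epsilon) (a hypothetical helper for illustration, not the library implementation):

    import numpy as np
    from scipy.special import logsumexp

    def uot_sinkhorn_log(a, b, M, eps, reg_m, n_iter=1000):
        # At convergence: logu = fi * (log a - log K v) and
        # logv = fi * (log b - log K^T u), with K = exp(-M / eps),
        # which is exactly what the test asserts.
        fi = reg_m / (reg_m + eps)
        logu, logv = np.zeros(len(a)), np.zeros(len(b))
        for _ in range(n_iter):
            logv = fi * (np.log(b + 1e-16)
                         - logsumexp(logu[:, None] - M / eps, axis=0))
            logu = fi * (np.log(a + 1e-16)
                         - logsumexp(logv[None, :] - M / eps, axis=1))
        return logu, logv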
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, - method="sinkhorn_stabilized", - reg_m=reg_m, - log=True, - verbose=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method="sinkhorn", log=True) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + G, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True + ) + G2, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True + ) + G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method="sinkhorn", log=True + ) + G = nx.to_numpy(G) + G2 = nx.to_numpy(G2) np.testing.assert_allclose(G, G2, atol=1e-5) + np.testing.assert_allclose(G2, G2_np, atol=1e-5) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_barycenter(method): +def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -138,25 +162,29 @@ def test_unbalanced_barycenter(method): epsilon = 1. reg_m = 1. - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True, verbose=True) + A, M = nx.from_numpy(A, M) + + q, log = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True + ) # check fixed point equations fi = reg_m / (reg_m + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logA = nx.log(A + 1e-16) + logq = nx.log(q + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logq - logKtu) u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) -def test_barycenter_stabilized_vs_sinkhorn(): +def test_barycenter_stabilized_vs_sinkhorn(nx): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -170,21 +198,24 @@ def test_barycenter_stabilized_vs_sinkhorn(): epsilon = 0.5 reg_m = 10 - qstable, log = barycenter_unbalanced(A, M, reg=epsilon, - reg_m=reg_m, log=True, - tau=100, - method="sinkhorn_stabilized", - verbose=True - ) - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method="sinkhorn", - log=True) + Ab, Mb = nx.from_numpy(A, M) - np.testing.assert_allclose( - q, qstable, atol=1e-05) + qstable, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100, + method="sinkhorn_stabilized", verbose=True + ) + q, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q_np, _ = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q, qstable = nx.to_numpy(q, qstable) + np.testing.assert_allclose(q, qstable, atol=1e-05) + np.testing.assert_allclose(q, q_np, atol=1e-05) -def test_wrong_method(): +def test_wrong_method(nx): n = 10 rng = np.random.RandomState(42) @@ -199,19 +230,20 @@ def test_wrong_method(): epsilon = 1. reg_m = 1. 
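The stabilized-vs-plain agreement exercised by test_barycenter_stabilized_vs_sinkhorn above can be reproduced standalone with the same parameters (an illustrative sketch mirroring that test, not additional test code):

    import numpy as np
    import ot
    from ot.unbalanced import barycenter_unbalanced

    rng = np.random.RandomState(42)
    x = rng.randn(100, 2)
    A = rng.rand(100, 2)              # two unnormalized histograms (one per column)
    M = ot.dist(x, x)
    M /= np.median(M)

    q_stable = barycenter_unbalanced(A, M, reg=0.5, reg_m=10, tau=100,
                                     method="sinkhorn_stabilized")
    q = barycenter_unbalanced(A, M, reg=0.5, reg_m=10, method="sinkhorn")
    np.testing.assert_allclose(q, q_stable, atol=1e-5)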
+ a, b, M = nx.from_numpy(a, b, M) + with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method='badmethod', - log=True, - verbose=True) + ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod', + log=True, verbose=True + ) with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method='badmethod', - verbose=True) + ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method='badmethod', verbose=True + ) -def test_implemented_methods(): +def test_implemented_methods(nx): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] @@ -228,6 +260,9 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. reg_m = 1. + + a, b, M, A = nx.from_numpy(a, b, M, A) + for method in IMPLEMENTED_METHODS: ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) -- cgit v1.2.3 From ac4cf442735ed4c0d5405ad861eddaa02afd4edd Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Mon, 11 Apr 2022 15:38:18 +0200 Subject: [MRG] MM algorithms for UOT (#362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb * Delete partial.py * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update releases.md with new MM UOT algorithms Co-authored-by: Rémi Flamary --- README.md | 6 +- RELEASES.md | 1 + docs/source/all.rst | 1 + examples/unbalanced-partial/plot_unbalanced_OT.py | 116 +++++ ot/partial.py | 84 +++- ot/regpath.py | 545 ++++++++++++++-------- ot/unbalanced.py | 223 +++++++++ test/test_unbalanced.py | 50 ++ 8 files changed, 802 insertions(+), 224 deletions(-) create mode 100644 examples/unbalanced-partial/plot_unbalanced_OT.py (limited to 'test/test_unbalanced.py') diff --git a/README.md b/README.md index 2ace69c..1b50aeb 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples): Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. -* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. 
+* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. @@ -309,4 +309,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. -[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. \ No newline at end of file +[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. + +[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index b54a84a..7942a15 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -19,6 +19,7 @@ - Add backend support for Domain Adaptation and Unbalanced solvers (PR #343). - Add (F)GW linear dictionary learning solvers + example (PR #319) - Add links to related PR and Issues in the doc release page (PR #350) +- Add new minimization-maximization algorithms for solving exact Unbalanced OT + example (PR #362) #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 3f7d029..1ec6be3 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -26,6 +26,7 @@ API and modules plot stochastic unbalanced + regpath partial sliced weak diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py new file mode 100644 index 0000000..03487e7 --- /dev/null +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +""" +============================================================== +2D examples of exact and entropic unbalanced optimal transport +============================================================== +This example is designed to show how to compute unbalanced and +partial OT in POT. 
+ +UOT aims at solving the following optimization problem: + + .. math:: + W = \min_{\gamma} <\gamma, \mathbf{M}>_F + + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + +where :math:`\mathrm{div}` is a divergence. +When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}` +should be the Kullback-Leibler divergence. +When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}` +can be either the Kullback-Leibler or the quadratic divergence. +Using :math:`\ell_1` norm gives the so-called partial OT. +""" + +# Author: Laetitia Chapel +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 40 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +n_noise = 10 + +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) +xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + +n = n + n_noise + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + + +############################################################################## +# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT +# ----------- + +reg = 0.005 +reg_m_kl = 0.05 +reg_m_l2 = 5 +mass = 0.7 + +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') +l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') +partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) + +############################################################################## +# Plot the results +# ---------------- + +pl.figure(2) +transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] +title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + + str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), + "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] + +for p in range(4): + pl.subplot(2, 4, p + 1) + P = transp[p] + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) + pl.title(title[p]) + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("mappings") + pl.subplot(2, 4, p + 5) + pl.imshow(P, cmap='jet') + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("transport plans") +pl.show() diff --git a/ot/partial.py b/ot/partial.py index b7093e4..0a9e450 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -7,7 +7,6 @@ Partial OT solvers # License: MIT License import numpy as np - from .lp import emd @@ -29,7 +28,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, \gamma &\geq 0 - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq 
\min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m & + \leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. @@ -50,7 +50,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, - :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining a given mass to be transported `m` - The formulation of the problem has been proposed in :ref:`[28] ` + The formulation of the problem has been proposed in + :ref:`[28] ` Parameters @@ -261,7 +262,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 + M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2 M_extended[:len(a), :len(b)] = M gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, @@ -455,7 +456,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[29] ` + The formulation of the problem has been proposed in + :ref:`[29] ` Parameters @@ -469,7 +471,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -623,16 +626,19 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, \gamma &\geq 0 - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[29] ` + The formulation of the problem has been proposed in + :ref:`[29] ` Parameters @@ -646,7 +652,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -728,21 +735,25 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, The function considers the following problem: .. 
math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, + \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ \gamma^T \mathbf{1} &\leq \mathbf{b} \\ \gamma &\geq 0 \\ - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ where : - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[3] ` (prop. 5) + The formulation of the problem has been proposed in + :ref:`[3] ` (prop. 5) Parameters @@ -829,12 +840,23 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, np.multiply(K, m / np.sum(K), out=K) err, cpt = 1, 0 + q1 = np.ones(K.shape) + q2 = np.ones(K.shape) + q3 = np.ones(K.shape) while (err > stopThr and cpt < numItermax): Kprev = K + K = K * q1 K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K) + q1 = q1 * Kprev / K1 + K1prev = K1 + K1 = K1 * q2 K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy))) + q2 = q2 * K1prev / K2 + K2prev = K2 + K2 = K2 * q3 K = K2 * (m / np.sum(K2)) + q3 = q3 * K2prev / K if np.any(np.isnan(K)) or np.any(np.isinf(K)): print('Warning: numerical errors at iteration', cpt) @@ -861,7 +883,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the partial Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: @@ -877,7 +900,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, \gamma^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -885,10 +909,13 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - :math:`\mathbf{C_2}` is the metric cost matrix in the target space - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: quadratic loss function - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in :ref:`[12] ` and the partial GW in :ref:`[29] ` + The formulation of the GW problem has been proposed in + :ref:`[12] ` and the + partial GW in :ref:`[29] ` Parameters ---------- @@ -903,7 +930,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: 
:math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -1005,13 +1033,15 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the partial Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, + \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) .. math:: s.t. \ \gamma &\geq 0 @@ -1028,10 +1058,13 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - :math:`\mathbf{C_2}` is the metric cost matrix in the target space - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L` : quadratic loss function - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in :ref:`[12] ` and the partial GW in :ref:`[29] ` + The formulation of the GW problem has been proposed in + :ref:`[12] ` and the + partial GW in :ref:`[29] ` Parameters @@ -1047,7 +1080,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional diff --git a/ot/regpath.py b/ot/regpath.py index 269937a..e745288 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -11,34 +11,48 @@ import scipy.sparse as sp def recast_ot_as_lasso(a, b, C): - r"""This function recasts the l2-penalized UOT problem as a Lasso problem + r"""This function recasts the l2-penalized UOT problem as a Lasso problem. + + Recall the l2-penalized UOT problem defined in + :ref:`[41] ` - Recall the l2-penalized UOT problem defined in [Chapel et al., 2021] .. math:: - UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + - \lambda \|T^T 1_n - b\|_2^2 + \text{UOT}_{\lambda} = \min_T + \lambda \|T 1_m - + \mathbf{a}\|_2^2 + + \lambda \|T^T 1_n - \mathbf{b}\|_2^2 + s.t. 
T \geq 0 + where : - - C is the (dim_a, dim_b) metric cost matrix - - :math:`\lambda` is the l2-regularization coefficient - - a and b are source and target distributions - - T is the transport plan to optimize - The problem above can be reformulated to a non-negative penalized + - :math:`C` is the cost matrix + - :math:`\lambda` is the l2-regularization parameter + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \ + distributions + - :math:`T` is the transport plan to optimize + + The problem above can be reformulated as a non-negative penalized linear regression problem, particularly Lasso + .. math:: - UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + \text{UOT2}_{\lambda} = \min_{\mathbf{t}} \gamma \mathbf{c}^T + \mathbf{t} + 0.5 * \|H \mathbf{t} - \mathbf{y}\|_2^2 + s.t. - t \geq 0 + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] - - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, - see [Chapel et al., 2021] for the design of H. The matrix product H t - computes both the source marginal and the target marginal. - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C` + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}` + - :math:`H` is a metric matrix, see :ref:`[41] ` for \ + the design of :math:`H`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport plan \ + :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -47,14 +61,16 @@ def recast_ot_as_lasso(a, b, C): Histogram of dimension dim_b C : np.ndarray, shape (dim_a, dim_b) Cost matrix + Returns ------- H : np.ndarray (dim_a+dim_b, dim_a*dim_b) - Auxiliary matrix constituted by 0 and 1 + Design matrix that contains only 0 and 1 y : np.ndarray (ns + nt, ) - Concatenation of histogram a and histogram b + Concatenation of histograms :math:`\mathbf{a}` and :math:`\mathbf{b}` c : np.ndarray (ns * nt, ) - Flattened array of cost matrix + Flattened array of the cost matrix + Examples -------- >>> import ot @@ -73,12 +89,12 @@ def recast_ot_as_lasso(a, b, C): >>> c array([16., 25., 28., 16., 40., 36.]) + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ dim_a = np.shape(a)[0] @@ -97,33 +113,47 @@ def recast_ot_as_lasso(a, b, C): def recast_semi_relaxed_as_lasso(a, b, C): - r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem + r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem. .. math:: - semi-relaxed UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + + \text{semi-relaxed UOT} = \min_T + + \lambda \|T 1_m - \mathbf{a}\|_2^2 + s.t. 
- T^T 1_n = b - t \geq 0 + T^T 1_n = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - C is the (dim_a, dim_b) metric cost matrix - - :math:`\lambda` is the l2-regularization coefficient - - a and b are source and target distributions - - T is the transport plan to optimize + + - :math:`C` is the metric cost matrix + - :math:`\lambda` is the l2-regularization parameter + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \ + distributions + - :math:`T` is the transport plan to optimize The problem above can be reformulated as follows + .. math:: - semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + \text{semi-relaxed UOT2} = \min_t \gamma \mathbf{c}^T t + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + s.t. - H_c t = b - t \geq 0 + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - H_r is a (dim_a, dim_a * dim_b) metric matrix, - which computes the sum along the rows of transport plan T - - H_c is a (dim_b, dim_a * dim_b) metric matrix, - which computes the sum along the columns of transport plan T - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is flattened version of the cost matrix :math:`C` + - :math:`\gamma = 1/\lambda` is the l2-regularization parameter + - :math:`H_r` is a metric matrix which computes the sum along the \ + rows of the transport plan :math:`T` + - :math:`H_c` is a metric matrix which computes the sum along the \ + columns of the transport plan :math:`T` + - :math:`\mathbf{t}` is the flattened version of :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -132,16 +162,18 @@ def recast_semi_relaxed_as_lasso(a, b, C): Histogram of dimension dim_b C : np.ndarray, shape (dim_a, dim_b) Cost matrix + Returns ------- Hr : np.ndarray (dim_a, dim_a * dim_b) Auxiliary matrix constituted by 0 and 1, which computes - the sum along the rows of transport plan T + the sum along the rows of transport plan :math:`T` Hc : np.ndarray (dim_b, dim_a * dim_b) Auxiliary matrix constituted by 0 and 1, which computes - the sum along the columns of transport plan T + the sum along the columns of transport plan :math:`T` c : np.ndarray (ns * nt, ) - Flattened array of cost matrix + Flattened array of the cost matrix + Examples -------- >>> import ot @@ -179,49 +211,60 @@ def recast_semi_relaxed_as_lasso(a, b, C): def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): r""" This function computes the next value of gamma if a variable - will be added in next iteration of the regularization path + is added in the next iteration of the regularization path. We look for the largest value of gamma such that the gradient of an inactive variable vanishes + .. 
math:: - \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i} + \max_{i \in \bar{A}} \frac{\mathbf{h}_i^T(H_A \phi - \mathbf{y})} + {\mathbf{h}_i^T H_A \delta - \mathbf{c}_i} + where : + - A is the current active set - - h_i is the ith column of auxiliary matrix H - - H_A is the sub-matrix constructed by the columns of H - whose indices belong to the active set A - - c_i is the ith element of cost vector c - - y is the concatenation of source and target distribution - - :math:`\phi` is the intercept of the solutions in current iteration - - :math:`\delta` is the slope of the solutions in current iteration + - :math:`\mathbf{h}_i` is the :math:`i` th column of the design \ + matrix :math:`{H}` + - :math:`{H}_A` is the sub-matrix constructed by the columns of \ + :math:`{H}` whose indices belong to the active set A + - :math:`\mathbf{c}_i` is the :math:`i` th element of the cost vector \ + :math:`\mathbf{c}` + - :math:`\mathbf{y}` is the concatenation of the source and target \ + distributions + - :math:`\phi` is the intercept of the solutions at the current iteration + - :math:`\delta` is the slope of the solutions at the current iteration + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : np.ndarray (size(A), ) + Intercept of the solutions at the current iteration + delta : np.ndarray (size(A), ) + Slope of the solutions at the current iteration HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H^T H + Matrix product of :math:`{H}^T {H}` Hty : np.ndarray (dim_a + dim_b, ) - Matrix product of H^T y + Matrix product of :math:`{H}^T \mathbf{y}` c: np.ndarray (dim_a * dim_b, ) - Flattened array of cost matrix C + Flattened array of the cost matrix :math:`{C}` active_index : list Indices of active variables current_gamma : float - Value of regularization coefficient at the start of current iteration + Value of the regularization parameter at the beginning of the current \ + iteration + Returns ------- next_gamma : float Value of gamma if a variable is added to active set in next iteration next_active_index : int Index of variable to be activated + + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ M = (HtH[:, active_index].dot(phi) - Hty) / \ (HtH[:, active_index].dot(delta) - c + 1e-16) @@ -237,56 +280,65 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, By taking the Lagrangian form of the problem, we obtain a similar update as the two-sided relaxed UOT + .. 
math:: - \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T - \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i} + + \max_{i \in \bar{A}} \frac{\mathbf{h}_{ri}^T(H_{rA} \phi - \mathbf{a}) + + \mathbf{h}_{c i}^T\phi_u}{\mathbf{h}_{r i}^T H_{r A} \delta + \ + \mathbf{h}_{c i} \delta_u - \mathbf{c}_i} + where : + - A is the current active set - - h_{r i} is the ith column of the matrix H_r - - h_{c i} is the ith column of the matrix H_c - - H_{r A} is the sub-matrix constructed by the columns of H_r - whose indices belong to the active set A - - c_i is the ith element of cost vector c - - y is the concatenation of source and target distribution + - :math:`\mathbf{h}_{r i}` is the ith column of the matrix :math:`H_r` + - :math:`\mathbf{h}_{c i}` is the ith column of the matrix :math:`H_c` + - :math:`H_{r A}` is the sub-matrix constructed by the columns of \ + :math:`H_r` whose indices belong to the active set A + - :math:`\mathbf{c}_i` is the :math:`i` th element of cost vector \ + :math:`\mathbf{c}` - :math:`\phi` is the intercept of the solutions in current iteration - :math:`\delta` is the slope of the solutions in current iteration - - :math:`\phi_u` is the intercept of Lagrange parameter in current - iteration - - :math:`\delta_u` is the slope of Lagrange parameter in current iteration + - :math:`\phi_u` is the intercept of Lagrange parameter at the \ + current iteration + - :math:`\delta_u` is the slope of Lagrange parameter at the \ + current iteration + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : np.ndarray (size(A), ) + Intercept of the solutions at the current iteration + delta : np.ndarray (size(A), ) + Slope of the solutions at the current iteration phi_u : np.ndarray (dim_b, ) - Intercept of the Lagrange parameter in current iteration (also linear) + Intercept of the Lagrange parameter at the current iteration delta_u : np.ndarray (dim_b, ) - Slope of the Lagrange parameter in current iteration (also linear) + Slope of the Lagrange parameter at the current iteration HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H_r^T H_r + Matrix product of :math:`H_r^T H_r` Hc : np.ndarray (dim_b, dim_a * dim_b) - Matrix that computes the sum along the columns of transport plan T + Matrix that computes the sum along the columns of the transport plan \ + :math:`T` Hra : np.ndarray (dim_a * dim_b, ) - Matrix product of H_r^T a + Matrix product of :math:`H_r^T \mathbf{a}` c: np.ndarray (dim_a * dim_b, ) - Flattened array of cost matrix C + Flattened array of cost matrix :math:`C` active_index : list Indices of active variables current_gamma : float Value of regularization coefficient at the start of current iteration + Returns ------- next_gamma : float Value of gamma if a variable is added to active set in next iteration next_active_index : int Index of variable to be activated + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. 
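A quick sanity check of the Lasso recast that these event formulas build on: the design matrix :math:`H` built by recast_ot_as_lasso stacks row sums over column sums, so :math:`H\mathbf{t}` recovers both marginals of the flattened plan (an illustrative doctest, assuming :math:`H` is returned as a scipy sparse matrix as in this module):

    >>> import numpy as np
    >>> import ot
    >>> a, b = np.array([0.3, 0.7]), np.array([0.6, 0.4])
    >>> C = np.arange(4.).reshape((2, 2))
    >>> H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C)
    >>> T = np.outer(a, b)  # any feasible plan
    >>> np.allclose(H.dot(T.ravel()), np.concatenate([T.sum(1), T.sum(0)]))
    True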
""" M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ @@ -297,37 +349,48 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, def compute_next_removal(phi, delta, current_gamma): - r""" This function computes the next value of gamma if a variable - is removed in next iteration of regularization path + r""" This function computes the next gamma value if a variable + is removed at the next iteration of the regularization path. + + We look for the largest value of the regularization parameter such that + an element of the current solution vanishes - We look for the largest value of gamma such that - an element of current solution vanishes .. math:: \max_{j \in A} \frac{\phi_j}{\delta_j} + where : + - A is the current active set - - phi_j is the jth element of the intercept of current solution - - delta_j is the jth elemnt of the slope of current solution + - :math:`\phi_j` is the :math:`j` th element of the intercept of the \ + current solution + - :math:`\delta_j` is the :math:`j` th element of the slope of the \ + current solution + + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : ndarray, shape (size(A), ) + Intercept of the solution at the current iteration + delta : ndarray, shape (size(A), ) + Slope of the solution at the current iteration current_gamma : float - Value of regularization coefficient at the start of current iteration + Value of the regularization parameter at the beginning of the \ + current iteration + Returns ------- next_removal_gamma : float - Value of gamma if a variable is removed in next iteration + Gamma value if a variable is removed at the next iteration next_removal_index : int - Index of the variable to remove in next iteration + Index of the variable to be removed at the next iteration + + + .. _references-regpath: References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ r_candidate = phi / (delta - 1e-16) r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0 @@ -335,56 +398,74 @@ def compute_next_removal(phi, delta, current_gamma): def complement_schur(M_current, b, d, id_pop): - r""" This function computes the inverse of matrix in regularization path - using Schur complement + r""" This function computes the inverse of the design matrix in the \ + regularization path using the Schur complement. Two cases may arise: + + Case 1: one variable is added to the active set + - Two cases may arise: Firstly one variable is added to the active set .. math:: M_{k+1}^{-1} = \begin{bmatrix} - M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\ - - s^{-1} b^T M_{k}^{-1} & s^{-1} + M_{k}^{-1} + s^{-1} M_{k}^{-1} \mathbf{b} \mathbf{b}^T M_{k}^{-1} \ + & - M_{k}^{-1} \mathbf{b} s^{-1} \\ + - s^{-1} \mathbf{b}^T M_{k}^{-1} & s^{-1} \end{bmatrix} + + where : - - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and - :math:`M_k` is the upper left block matrix in Schur formulation - - b is the upper right block matrix in Schur formulation. 
In our case, - b is reduced to a column vector and b^T is the lower left block matrix - - s is the Schur complement, given by - :math:`s = d - b^T M_{k}^{-1} b` in our case - - Secondly, one variable is removed from the active set + + - :math:`M_k^{-1}` is the inverse of the design matrix :math:`H_A^tH_A` \ + of the previous iteration + - :math:`\mathbf{b}` is the last column of :math:`M_{k}` + - :math:`s` is the Schur complement, given by \ + :math:`s = \mathbf{d} - \mathbf{b}^T M_{k}^{-1} \mathbf{b}` + + Case 2: one variable is removed from the active set. + .. math:: - M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} - + M_{k+1}^{-1} = M^{-1}_{k \backslash q} - \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}} + where : - - q is the index of column and row to delete - - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix - without qth column and qth row - - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element - - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}` + + - :math:`q` is the index of column and row to delete + - :math:`M^{-1}_{k \backslash q}` is the previous inverse matrix deprived \ + of the :math:`q` th column and :math:`q` th row + - :math:`r_{-q,q}` is the :math:`q` th column of :math:`M^{-1}_{k}` \ + without the :math:`q` th element + - :math:`r_{q, q}` is the element of :math:`q` th column and :math:`q` th \ + row in :math:`M^{-1}_{k}` + + Parameters ---------- - M_current : np.ndarray (|A|-1, |A|-1) - Inverse matrix in previous iteration - b : np.ndarray (|A|-1, ) - Upper right matrix in Schur complement, a column vector in our case + M_current : ndarray, shape (size(A)-1, size(A)-1) + Inverse matrix of :math:`H_A^tH_A` at the previous iteration, with \ + size(A) the size of the active set + b : ndarray, shape (size(A)-1, ) + None for case 2 (removal), last column of :math:`M_{k}` for case 1 \ + (addition) d : float - Lower right matrix in Schur complement, a scalar in our case - id_pop + should be equal to 2 when UOT and 1 for the semi-relaxed OT + id_pop : int Index of the variable to be removed, equal to -1 - if none of the variables is deleted in current iteration + if no variable is deleted at the current iteration + + Returns ------- - M : np.ndarray (|A|, |A|) - Inverse matrix needed in current iteration + M : ndarray, shape (size(A), size(A)) + Inverse matrix of :math:`H_A^tH_A` of the current iteration + + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ + if b is None: b = M_current[id_pop, :] b = np.delete(b, id_pop) @@ -409,33 +490,39 @@ def complement_schur(M_current, b, d, id_pop): def construct_augmented_H(active_index, m, Hc, HrHr): - r""" This function construct an augmented matrix for the first iteration of - semi-relaxed regularization path + r""" This function constructs an augmented matrix for the first iteration + of the semi-relaxed regularization path .. 
math:: - Augmented_H = + \text{Augmented}_H = \begin{bmatrix} 0 & H_{c A} \\ H_{c A}^T & H_{r A}^T H_{r A} \end{bmatrix} + where : - - H_{r A} is the sub-matrix constructed by the columns of H_r - whose indices belong to the active set A - - H_{c A} is the sub-matrix constructed by the columns of H_c - whose indices belong to the active set A + + - :math:`H_{r A}` is the sub-matrix constructed by the columns of \ + :math:`H_r` whose indices belong to the active set A + - :math:`H_{c A}` is the sub-matrix constructed by the columns of \ + :math:`H_c` whose indices belong to the active set A + + Parameters ---------- active_index : list - Indices of active variables + Indices of the active variables m : int Length of the target distribution Hc : np.ndarray (dim_b, dim_a * dim_b) - Matrix that computes the sum along the columns of transport plan T + Matrix that computes the sum along the columns of the transport plan \ + :math:`T` HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H_r^T H_r + Matrix product of :math:`H_r^T H_r` + Returns ------- - H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|) + H_augmented : np.ndarray (dim_b + size(A), dim_b + size(A)) Augmented matrix for the first iteration of the semi-relaxed regularization path """ @@ -451,18 +538,27 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, r"""This function gives the regularization path of l2-penalized UOT problem The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: - \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2 + s.t. - t \geq 0 + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`{C}` - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] - - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, - see [Chapel et al., 2021] for the design of H. The matrix product Ht - computes both the source marginal and the target marginal. - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}`, defined as \ + :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]` + - :math:`{H}` is a design matrix, see :ref:`[41] ` \ + for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport matrix + Parameters ---------- a : np.ndarray (dim_a,) @@ -478,11 +574,12 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the optimal transport matrix t_list : list - List of solutions in regularization path + List of solutions in the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of regularization coefficients in the regularization path + Examples -------- >>> import ot @@ -502,10 +599,9 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. 
(2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ n = np.shape(a)[0] @@ -580,22 +676,32 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, itmax=50000): r"""This function gives the regularization path of semi-relaxed - l2-UOT problem + l2-UOT problem. The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: - \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + + \min_t \gamma \mathbf{c}^T t + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + s.t. - H_c t = b - t \geq 0 + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - H_r is a (dim_a, dim_a * dim_b) metric matrix, - which computes the sum along the rows of transport plan T - - H_c is a (dim_b, dim_a * dim_b) metric matrix, - which computes the sum along the columns of transport plan T - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`C` + - :math:`\gamma = 1/\lambda` is the l2-regularization parameter + - :math:`H_r` is a matrix that computes the sum along the rows of \ + the transport plan :math:`T` + - :math:`H_c` is a matrix that computes the sum along the columns of \ + the transport plan :math:`T` + - :math:`\mathbf{t}` is the flattened version of the transport plan \ + :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -608,14 +714,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, l2-regularization coefficient itmax: int (optional) Maximum number of iteration + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the (unregularized) optimal transport matrix t_list : list - List of solutions in regularization path + List of all the optimal transport vectors of the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of the regularization parameters in the path + Examples -------- >>> import ot @@ -635,10 +743,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ n = np.shape(a)[0] @@ -722,8 +829,44 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, semi_relaxed=False, itmax=50000): - r"""This function combines both the semi-relaxed and the fully-relaxed - regularization paths of l2-UOT problem + r"""This function provides all the solutions of the regularization path \ + of the l2-UOT problem :ref:`[41] `. + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + + .. math:: + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2 + + s.t. 
+ \mathbf{t} \geq 0 + + where : + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`{C}` + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}`, defined as \ + :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]` + - :math:`{H}` is a design matrix, see :ref:`[41] ` \ + for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport matrix + + For the semi-relaxed problem, it optimizes the Lasso reformulation of the + l2-penalized UOT: + + .. math:: + + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + + s.t. + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + Parameters ---------- @@ -736,23 +879,24 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, reg: float (optional) l2-regularization coefficient semi_relaxed : bool (optional) - Give the semi-relaxed path if true + Give the semi-relaxed path if True itmax: int (optional) Maximum number of iteration + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the (unregularized) optimal transport matrix t_list : list - List of solutions in regularization path + List of all the optimal transport vectors of the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of the regularization parameters in the path + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ if semi_relaxed: t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, @@ -765,27 +909,33 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def compute_transport_plan(gamma, gamma_list, Pi_list): r""" Given the regularization path, this function computes the transport - plan for any value of gamma by the piecewise linearity of the path + plan for any value of gamma thanks to the piecewise linearity of the path. .. math:: t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma) - where : - - :math:`\gamma` is the regularization coefficient + + where: + + - :math:`\gamma` is the regularization parameter - :math:`\phi(\gamma)` is the corresponding intercept - :math:`\delta(\gamma)` is the corresponding slope - - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix) + - :math:`\mathbf{t}` is the flattened version of the transport matrix + Parameters ---------- gamma : float Regularization coefficient gamma_list : list - List of regularization coefficients in regularization path + List of regularization parameters of the regularization path Pi_list : list - List of solutions in regularization path + List of all the solutions of the regularization path + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Transport vector corresponding to the given value of gamma + Vectorization of the transport plan corresponding to the given value + of gamma + Examples -------- >>> import ot @@ -804,12 +954,13 @@ def compute_transport_plan(gamma, gamma_list, Pi_list): array([0. , 0. , 0. , 0.19722222, 0.05555556, 0. , 0. , 0.24722222, 0. ]) + + .. 
+    .. _references-regpath:
     References
     ----------
-    [Chapel et al., 2021]:
-        Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
         Unbalanced optimal transport through non-negative penalized
-        linear regression.
+        linear regression. NeurIPS.
     """
     if gamma >= gamma_list[0]:
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 503cc1e..90c920c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -4,6 +4,7 @@ Regularized Unbalanced OT solvers
 """
 # Author: Hicham Janati
+# Laetitia Chapel
 # License: MIT License
 
 from __future__ import division
@@ -1029,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
                                       log=log, **kwargs)
     else:
         raise ValueError("Unknown method '%s'." % method)
+
+
+def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+                  stopThr=1e-15, verbose=False, log=False):
+    r"""
+    Solve the unbalanced optimal transport problem and return the OT plan.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+        s.t.
+            \gamma \geq 0
+
+    where:
+
+    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+      unbalanced distributions
+    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler
+      or :math:`\ell_2` divergence
+
+    The algorithm used for solving the problem is a
+    maximization-minimization algorithm as proposed in
+    :ref:`[41] <references-regpath>`.
+
+    Parameters
+    ----------
+    a : array-like (dim_a,)
+        Unnormalized histogram of dimension `dim_a`
+    b : array-like (dim_b,)
+        Unnormalized histogram of dimension `dim_b`
+    M : array-like (dim_a, dim_b)
+        loss matrix
+    reg_m: float
+        Marginal relaxation term > 0
+    div: string, optional
+        Divergence to quantify the difference between the marginals.
+        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+    G0: array-like (dim_a, dim_b), optional
+        Initialization of the transport matrix
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    gamma : (dim_a, dim_b) array-like
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if `log` is `True`
+
+    Examples
+    --------
+    >>> import ot
+    >>> import numpy as np
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> M=[[1., 36.],[9., 4.]]
+    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2)
+    array([[0.3 , 0.  ],
+           [0.  , 0.07]])
+    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2)
+    array([[0.25, 0.  ],
+           [0.  , 0.  ]])
+
+
+    .. _references-regpath:
+    References
+    ----------
+    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+        Unbalanced optimal transport through non-negative penalized
+        linear regression. NeurIPS.
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
+    """
+    M, a, b = list_to_array(M, a, b)
+    nx = get_backend(M, a, b)
+
+    dim_a, dim_b = M.shape
+
+    if len(a) == 0:
+        a = nx.ones(dim_a, type_as=M) / dim_a
+    if len(b) == 0:
+        b = nx.ones(dim_b, type_as=M) / dim_b
+
+    if G0 is None:
+        G = a[:, None] * b[None, :]
+    else:
+        G = G0
+
+    if log:
+        log = {'err': [], 'G': []}
+
+    if div == 'kl':
+        K = nx.exp(M / - reg_m / 2)
+    elif div == 'l2':
+        K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2,
+                       nx.zeros((dim_a, dim_b), type_as=M))
+    else:
+        warnings.warn("The div parameter should be either equal to 'kl' or \
+            'l2': it has been set to 'kl'.")
+        div = 'kl'
+        K = nx.exp(M / - reg_m / 2)
+
+    for i in range(numItermax):
+        Gprev = G
+
+        if div == 'kl':
+            u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16))
+            v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16))
+            G = G * K * u[:, None] * v[None, :]
+        elif div == 'l2':
+            Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16
+            G = G * K / Gd
+
+        err = nx.sqrt(nx.sum((G - Gprev) ** 2))
+        if log:
+            log['err'].append(err)
+            log['G'].append(G)
+        if verbose:
+            print('{:5d}|{:8e}|'.format(i, err))
+        if err < stopThr:
+            break
+
+    if log:
+        log['cost'] = nx.sum(G * M)
+        return G, log
+    else:
+        return G
+
+
+def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+                   stopThr=1e-15, verbose=False, log=False):
+    r"""
+    Solve the unbalanced optimal transport problem and return the OT cost.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+        s.t.
+            \gamma \geq 0
+
+    where:
+
+    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+      unbalanced distributions
+    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler
+      or :math:`\ell_2` divergence
+
+    The algorithm used for solving the problem is a
+    maximization-minimization algorithm as proposed in
+    :ref:`[41] <references-regpath>`.
+
+    Parameters
+    ----------
+    a : array-like (dim_a,)
+        Unnormalized histogram of dimension `dim_a`
+    b : array-like (dim_b,)
+        Unnormalized histogram of dimension `dim_b`
+    M : array-like (dim_a, dim_b)
+        loss matrix
+    reg_m: float
+        Marginal relaxation term > 0
+    div: string, optional
+        Divergence to quantify the difference between the marginals.
+        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+    G0: array-like (dim_a, dim_b), optional
+        Initialization of the transport matrix
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (> 0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+    Returns
+    -------
+    ot_distance : array-like
+        the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+    log : dict
+        log dictionary returned only if `log` is `True`
+
+    Examples
+    --------
+    >>> import ot
+    >>> import numpy as np
+    >>> a=[.5, .5]
+    >>> b=[.5, .5]
+    >>> M=[[1., 36.],[9., 4.]]
+    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'), 2)
+    0.25
+    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'), 2)
+    0.57
+
+    References
+    ----------
+    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+        Unbalanced optimal transport through non-negative penalized
+        linear regression. NeurIPS.
+
+    See Also
+    --------
+    ot.lp.emd2 : Unregularized OT loss
+    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+    """
+    _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0,
+                              numItermax=numItermax, stopThr=stopThr,
+                              verbose=verbose, log=True)
+
+    if log:
+        return log_mm['cost'], log_mm
+    else:
+        return log_mm['cost']
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index db59504..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
 """Tests for module Unbalanced OT with entropy regularization"""
 
 # Author: Hicham Janati
+# Laetitia Chapel
 #
 # License: MIT License
@@ -286,3 +287,52 @@ def test_implemented_methods(nx):
                               method=method)
         barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
                               method=method)
+
+
+def test_mm_convergence(nx):
+    n = 100
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
+    rng = np.random.RandomState(75)
+    y = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+    b = ot.utils.unif(n)
+
+    M = ot.dist(x, y)
+    M = M / M.max()
+    reg_m = 100
+    a, b, M = nx.from_numpy(a, b, M)
+
+    G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+                                          verbose=True, log=True)
+    loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+        a, b, M, reg_m, div='kl', verbose=True))
+    G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+                                          verbose=False, log=True)
+
+    # check if the marginals come close to the true ones when reg_m is large
+    np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+    np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+    np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+    np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+    # check if mm_unbalanced2 returns the correct loss
+    np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+                               atol=1e-5)
+
+    # check in case no histogram is provided
+    a_np, b_np = np.array([]), np.array([])
+    a, b = nx.from_numpy(a_np, b_np)
+
+    G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+    G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+    np.testing.assert_allclose(G_kl_null, G_kl)
+    np.testing.assert_allclose(G_l2_null, G_l2)
+
+    # test when G0 is given
+    G0 = ot.emd(a, b, M)
+    reg_m = 10000
+    G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+    G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+    np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+    np.testing.assert_allclose(G0, G_l2, atol=1e-05)
-- 
cgit v1.2.3