summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaetitia Chapel <laetitia.chapel@univ-ubs.fr>2022-04-11 15:38:18 +0200
committerGitHub <noreply@github.com>2022-04-11 15:38:18 +0200
commitac4cf442735ed4c0d5405ad861eddaa02afd4edd (patch)
tree6f0bf54ca7452621bc55548f2a2a2615b8975b54
parent0b223ff883fd73601984a92c31cb70d4aded16e8 (diff)
[MRG] MM algorithms for UOT (#362)
* bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb * Delete partial.py * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update releases.md with new MM UOT algorithms Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--README.md6
-rw-r--r--RELEASES.md1
-rw-r--r--docs/source/all.rst1
-rw-r--r--examples/unbalanced-partial/plot_unbalanced_OT.py116
-rwxr-xr-xot/partial.py84
-rw-r--r--ot/regpath.py545
-rw-r--r--ot/unbalanced.py223
-rw-r--r--test/test_unbalanced.py50
8 files changed, 802 insertions, 224 deletions
diff --git a/README.md b/README.md
index 2ace69c..1b50aeb 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples):
Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
-* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
+* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
@@ -309,4 +309,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
-[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. \ No newline at end of file
+[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
+
+[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index b54a84a..7942a15 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -19,6 +19,7 @@
- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343).
- Add (F)GW linear dictionary learning solvers + example (PR #319)
- Add links to related PR and Issues in the doc release page (PR #350)
+- Add new minimization-maximization algorithms for solving exact Unbalanced OT + example (PR #362)
#### Closed issues
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 3f7d029..1ec6be3 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -26,6 +26,7 @@ API and modules
plot
stochastic
unbalanced
+ regpath
partial
sliced
weak
diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py
new file mode 100644
index 0000000..03487e7
--- /dev/null
+++ b/examples/unbalanced-partial/plot_unbalanced_OT.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+"""
+==============================================================
+2D examples of exact and entropic unbalanced optimal transport
+==============================================================
+This example is designed to show how to compute unbalanced and
+partial OT in POT.
+
+UOT aims at solving the following optimization problem:
+
+ .. math::
+ W = \min_{\gamma} <\gamma, \mathbf{M}>_F +
+ \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+where :math:`\mathrm{div}` is a divergence.
+When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}`
+should be the Kullback-Leibler divergence.
+When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}`
+can be either the Kullback-Leibler or the quadratic divergence.
+Using :math:`\ell_1` norm gives the so-called partial OT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 40 # nb samples
+
+mu_s = np.array([-1, -1])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+np.random.seed(0)
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+n_noise = 10
+
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)
+
+n = n + n_noise
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+
+##############################################################################
+# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT
+# -----------
+
+reg = 0.005
+reg_m_kl = 0.05
+reg_m_l2 = 5
+mass = 0.7
+
+entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)
+kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl')
+l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2')
+partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)
+
+##############################################################################
+# Plot the results
+# ----------------
+
+pl.figure(2)
+transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
+title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" +
+ str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl),
+ "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)]
+
+for p in range(4):
+ pl.subplot(2, 4, p + 1)
+ P = transp[p]
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.3)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2)
+ pl.title(title[p])
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("mappings")
+ pl.subplot(2, 4, p + 5)
+ pl.imshow(P, cmap='jet')
+ pl.yticks(())
+ pl.xticks(())
+ if p < 1:
+ pl.ylabel("transport plans")
+pl.show()
diff --git a/ot/partial.py b/ot/partial.py
index b7093e4..0a9e450 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -7,7 +7,6 @@ Partial OT solvers
# License: MIT License
import numpy as np
-
from .lp import emd
@@ -29,7 +28,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
\gamma &\geq 0
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &
+ \leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X.
@@ -50,7 +50,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
- :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining
a given mass to be transported `m`
- The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>`
+ The formulation of the problem has been proposed in
+ :ref:`[28] <references-partial-wasserstein-lagrange>`
Parameters
@@ -261,7 +262,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
M_extended[:len(a), :len(b)] = M
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
@@ -455,7 +456,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>`
+ The formulation of the problem has been proposed in
+ :ref:`[29] <references-partial-gromov-wasserstein>`
Parameters
@@ -469,7 +471,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported
+ (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -623,16 +626,19 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
\gamma &\geq 0
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- :math:`\mathbf{M}` is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>`
+ The formulation of the problem has been proposed in
+ :ref:`[29] <references-partial-gromov-wasserstein2>`
Parameters
@@ -646,7 +652,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported
+ (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -728,21 +735,25 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
The function considers the following problem:
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma,
+ \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\
\gamma^T \mathbf{1} &\leq \mathbf{b} \\
\gamma &\geq 0 \\
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
where :
- :math:`\mathbf{M}` is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
+ The formulation of the problem has been proposed in
+ :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
Parameters
@@ -829,12 +840,23 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
np.multiply(K, m / np.sum(K), out=K)
err, cpt = 1, 0
+ q1 = np.ones(K.shape)
+ q2 = np.ones(K.shape)
+ q3 = np.ones(K.shape)
while (err > stopThr and cpt < numItermax):
Kprev = K
+ K = K * q1
K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ q1 = q1 * Kprev / K1
+ K1prev = K1
+ K1 = K1 * q2
K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ q2 = q2 * K1prev / K2
+ K2prev = K2
+ K2 = K2 * q3
K = K2 * (m / np.sum(K2))
+ q3 = q3 * K2prev / K
if np.any(np.isnan(K)) or np.any(np.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
@@ -861,7 +883,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the partial Gromov-Wasserstein transport between
+ :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
@@ -877,7 +900,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
\gamma^T \mathbf{1} &\leq \mathbf{b}
- \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
+ \mathbf{1}^T \gamma^T \mathbf{1} = m
+ &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
@@ -885,10 +909,13 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L`: quadratic loss function
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
+ The formulation of the GW problem has been proposed in
+ :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the
+ partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>`
Parameters
----------
@@ -903,7 +930,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported (default:
+ :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -1005,13 +1033,15 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the partial Gromov-Wasserstein discrepancy between
+ :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
+ GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k},
+ \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
.. math::
s.t. \ \gamma &\geq 0
@@ -1028,10 +1058,13 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
- :math:`\mathbf{C_2}` is the metric cost matrix in the target space
- :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
- `L` : quadratic loss function
- - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term,
+ :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
+ The formulation of the GW problem has been proposed in
+ :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the
+ partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>`
Parameters
@@ -1047,7 +1080,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
+ Amount of mass to be transported (default:
+ :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
diff --git a/ot/regpath.py b/ot/regpath.py
index 269937a..e745288 100644
--- a/ot/regpath.py
+++ b/ot/regpath.py
@@ -11,34 +11,48 @@ import scipy.sparse as sp
def recast_ot_as_lasso(a, b, C):
- r"""This function recasts the l2-penalized UOT problem as a Lasso problem
+ r"""This function recasts the l2-penalized UOT problem as a Lasso problem.
+
+ Recall the l2-penalized UOT problem defined in
+ :ref:`[41] <references-regpath>`
- Recall the l2-penalized UOT problem defined in [Chapel et al., 2021]
.. math::
- UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 +
- \lambda \|T^T 1_n - b\|_2^2
+ \text{UOT}_{\lambda} = \min_T <C, T> + \lambda \|T 1_m -
+ \mathbf{a}\|_2^2 +
+ \lambda \|T^T 1_n - \mathbf{b}\|_2^2
+
s.t.
T \geq 0
+
where :
- - C is the (dim_a, dim_b) metric cost matrix
- - :math:`\lambda` is the l2-regularization coefficient
- - a and b are source and target distributions
- - T is the transport plan to optimize
- The problem above can be reformulated to a non-negative penalized
+ - :math:`C` is the cost matrix
+ - :math:`\lambda` is the l2-regularization parameter
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \
+ distributions
+ - :math:`T` is the transport plan to optimize
+
+ The problem above can be reformulated as a non-negative penalized
linear regression problem, particularly Lasso
+
.. math::
- UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ \text{UOT2}_{\lambda} = \min_{\mathbf{t}} \gamma \mathbf{c}^T
+ \mathbf{t} + 0.5 * \|H \mathbf{t} - \mathbf{y}\|_2^2
+
s.t.
- t \geq 0
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
- - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
- see [Chapel et al., 2021] for the design of H. The matrix product H t
- computes both the source marginal and the target marginal.
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C`
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`
+ - :math:`H` is a metric matrix, see :ref:`[41] <references-regpath>` for \
+ the design of :math:`H`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport plan \
+ :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -47,14 +61,16 @@ def recast_ot_as_lasso(a, b, C):
Histogram of dimension dim_b
C : np.ndarray, shape (dim_a, dim_b)
Cost matrix
+
Returns
-------
H : np.ndarray (dim_a+dim_b, dim_a*dim_b)
- Auxiliary matrix constituted by 0 and 1
+ Design matrix that contains only 0 and 1
y : np.ndarray (ns + nt, )
- Concatenation of histogram a and histogram b
+ Concatenation of histograms :math:`\mathbf{a}` and :math:`\mathbf{b}`
c : np.ndarray (ns * nt, )
- Flattened array of cost matrix
+ Flattened array of the cost matrix
+
Examples
--------
>>> import ot
@@ -73,12 +89,12 @@ def recast_ot_as_lasso(a, b, C):
>>> c
array([16., 25., 28., 16., 40., 36.])
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
dim_a = np.shape(a)[0]
@@ -97,33 +113,47 @@ def recast_ot_as_lasso(a, b, C):
def recast_semi_relaxed_as_lasso(a, b, C):
- r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem
+ r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem.
.. math::
- semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2
+
+ \text{semi-relaxed UOT} = \min_T <C, T>
+ + \lambda \|T 1_m - \mathbf{a}\|_2^2
+
s.t.
- T^T 1_n = b
- t \geq 0
+ T^T 1_n = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - C is the (dim_a, dim_b) metric cost matrix
- - :math:`\lambda` is the l2-regularization coefficient
- - a and b are source and target distributions
- - T is the transport plan to optimize
+
+ - :math:`C` is the metric cost matrix
+ - :math:`\lambda` is the l2-regularization parameter
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \
+ distributions
+ - :math:`T` is the transport plan to optimize
The problem above can be reformulated as follows
+
.. math::
- semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ \text{semi-relaxed UOT2} = \min_t \gamma \mathbf{c}^T t
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
s.t.
- H_c t = b
- t \geq 0
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - H_r is a (dim_a, dim_a * dim_b) metric matrix,
- which computes the sum along the rows of transport plan T
- - H_c is a (dim_b, dim_a * dim_b) metric matrix,
- which computes the sum along the columns of transport plan T
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is flattened version of the cost matrix :math:`C`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization parameter
+ - :math:`H_r` is a metric matrix which computes the sum along the \
+ rows of the transport plan :math:`T`
+ - :math:`H_c` is a metric matrix which computes the sum along the \
+ columns of the transport plan :math:`T`
+ - :math:`\mathbf{t}` is the flattened version of :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -132,16 +162,18 @@ def recast_semi_relaxed_as_lasso(a, b, C):
Histogram of dimension dim_b
C : np.ndarray, shape (dim_a, dim_b)
Cost matrix
+
Returns
-------
Hr : np.ndarray (dim_a, dim_a * dim_b)
Auxiliary matrix constituted by 0 and 1, which computes
- the sum along the rows of transport plan T
+ the sum along the rows of transport plan :math:`T`
Hc : np.ndarray (dim_b, dim_a * dim_b)
Auxiliary matrix constituted by 0 and 1, which computes
- the sum along the columns of transport plan T
+ the sum along the columns of transport plan :math:`T`
c : np.ndarray (ns * nt, )
- Flattened array of cost matrix
+ Flattened array of the cost matrix
+
Examples
--------
>>> import ot
@@ -179,49 +211,60 @@ def recast_semi_relaxed_as_lasso(a, b, C):
def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma):
r""" This function computes the next value of gamma if a variable
- will be added in next iteration of the regularization path
+ is added in the next iteration of the regularization path.
We look for the largest value of gamma such that
the gradient of an inactive variable vanishes
+
.. math::
- \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i}
+ \max_{i \in \bar{A}} \frac{\mathbf{h}_i^T(H_A \phi - \mathbf{y})}
+ {\mathbf{h}_i^T H_A \delta - \mathbf{c}_i}
+
where :
+
- A is the current active set
- - h_i is the ith column of auxiliary matrix H
- - H_A is the sub-matrix constructed by the columns of H
- whose indices belong to the active set A
- - c_i is the ith element of cost vector c
- - y is the concatenation of source and target distribution
- - :math:`\phi` is the intercept of the solutions in current iteration
- - :math:`\delta` is the slope of the solutions in current iteration
+ - :math:`\mathbf{h}_i` is the :math:`i` th column of the design \
+ matrix :math:`{H}`
+ - :math:`{H}_A` is the sub-matrix constructed by the columns of \
+ :math:`{H}` whose indices belong to the active set A
+ - :math:`\mathbf{c}_i` is the :math:`i` th element of the cost vector \
+ :math:`\mathbf{c}`
+ - :math:`\mathbf{y}` is the concatenation of the source and target \
+ distributions
+ - :math:`\phi` is the intercept of the solutions at the current iteration
+ - :math:`\delta` is the slope of the solutions at the current iteration
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : np.ndarray (size(A), )
+ Intercept of the solutions at the current iteration
+ delta : np.ndarray (size(A), )
+ Slope of the solutions at the current iteration
HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H^T H
+ Matrix product of :math:`{H}^T {H}`
Hty : np.ndarray (dim_a + dim_b, )
- Matrix product of H^T y
+ Matrix product of :math:`{H}^T \mathbf{y}`
c: np.ndarray (dim_a * dim_b, )
- Flattened array of cost matrix C
+ Flattened array of the cost matrix :math:`{C}`
active_index : list
Indices of active variables
current_gamma : float
- Value of regularization coefficient at the start of current iteration
+ Value of the regularization parameter at the beginning of the current \
+ iteration
+
Returns
-------
next_gamma : float
Value of gamma if a variable is added to active set in next iteration
next_active_index : int
Index of variable to be activated
+
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
M = (HtH[:, active_index].dot(phi) - Hty) / \
(HtH[:, active_index].dot(delta) - c + 1e-16)
@@ -237,56 +280,65 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
By taking the Lagrangian form of the problem, we obtain a similar update
as the two-sided relaxed UOT
+
.. math::
- \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T
- \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i}
+
+ \max_{i \in \bar{A}} \frac{\mathbf{h}_{ri}^T(H_{rA} \phi - \mathbf{a})
+ + \mathbf{h}_{c i}^T\phi_u}{\mathbf{h}_{r i}^T H_{r A} \delta + \
+ \mathbf{h}_{c i} \delta_u - \mathbf{c}_i}
+
where :
+
- A is the current active set
- - h_{r i} is the ith column of the matrix H_r
- - h_{c i} is the ith column of the matrix H_c
- - H_{r A} is the sub-matrix constructed by the columns of H_r
- whose indices belong to the active set A
- - c_i is the ith element of cost vector c
- - y is the concatenation of source and target distribution
+ - :math:`\mathbf{h}_{r i}` is the ith column of the matrix :math:`H_r`
+ - :math:`\mathbf{h}_{c i}` is the ith column of the matrix :math:`H_c`
+ - :math:`H_{r A}` is the sub-matrix constructed by the columns of \
+ :math:`H_r` whose indices belong to the active set A
+ - :math:`\mathbf{c}_i` is the :math:`i` th element of cost vector \
+ :math:`\mathbf{c}`
- :math:`\phi` is the intercept of the solutions in current iteration
- :math:`\delta` is the slope of the solutions in current iteration
- - :math:`\phi_u` is the intercept of Lagrange parameter in current
- iteration
- - :math:`\delta_u` is the slope of Lagrange parameter in current iteration
+ - :math:`\phi_u` is the intercept of Lagrange parameter at the \
+ current iteration
+ - :math:`\delta_u` is the slope of Lagrange parameter at the \
+ current iteration
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : np.ndarray (size(A), )
+ Intercept of the solutions at the current iteration
+ delta : np.ndarray (size(A), )
+ Slope of the solutions at the current iteration
phi_u : np.ndarray (dim_b, )
- Intercept of the Lagrange parameter in current iteration (also linear)
+ Intercept of the Lagrange parameter at the current iteration
delta_u : np.ndarray (dim_b, )
- Slope of the Lagrange parameter in current iteration (also linear)
+ Slope of the Lagrange parameter at the current iteration
HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H_r^T H_r
+ Matrix product of :math:`H_r^T H_r`
Hc : np.ndarray (dim_b, dim_a * dim_b)
- Matrix that computes the sum along the columns of transport plan T
+ Matrix that computes the sum along the columns of the transport plan \
+ :math:`T`
Hra : np.ndarray (dim_a * dim_b, )
- Matrix product of H_r^T a
+ Matrix product of :math:`H_r^T \mathbf{a}`
c: np.ndarray (dim_a * dim_b, )
- Flattened array of cost matrix C
+ Flattened array of cost matrix :math:`C`
active_index : list
Indices of active variables
current_gamma : float
Value of regularization coefficient at the start of current iteration
+
Returns
-------
next_gamma : float
Value of gamma if a variable is added to active set in next iteration
next_active_index : int
Index of variable to be activated
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \
@@ -297,37 +349,48 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
def compute_next_removal(phi, delta, current_gamma):
- r""" This function computes the next value of gamma if a variable
- is removed in next iteration of regularization path
+ r""" This function computes the next gamma value if a variable
+ is removed at the next iteration of the regularization path.
+
+ We look for the largest value of the regularization parameter such that
+ an element of the current solution vanishes
- We look for the largest value of gamma such that
- an element of current solution vanishes
.. math::
\max_{j \in A} \frac{\phi_j}{\delta_j}
+
where :
+
- A is the current active set
- - phi_j is the jth element of the intercept of current solution
- - delta_j is the jth elemnt of the slope of current solution
+ - :math:`\phi_j` is the :math:`j` th element of the intercept of the \
+ current solution
+ - :math:`\delta_j` is the :math:`j` th element of the slope of the \
+ current solution
+
+
Parameters
----------
- phi : np.ndarray (|A|, )
- Intercept of the solutions in current iteration (t is piecewise linear)
- delta : np.ndarray (|A|, )
- Slope of the solutions in current iteration (t is piecewise linear)
+ phi : ndarray, shape (size(A), )
+ Intercept of the solution at the current iteration
+ delta : ndarray, shape (size(A), )
+ Slope of the solution at the current iteration
current_gamma : float
- Value of regularization coefficient at the start of current iteration
+ Value of the regularization parameter at the beginning of the \
+ current iteration
+
Returns
-------
next_removal_gamma : float
- Value of gamma if a variable is removed in next iteration
+ Gamma value if a variable is removed at the next iteration
next_removal_index : int
- Index of the variable to remove in next iteration
+ Index of the variable to be removed at the next iteration
+
+
+ .. _references-regpath:
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
r_candidate = phi / (delta - 1e-16)
r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0
@@ -335,56 +398,74 @@ def compute_next_removal(phi, delta, current_gamma):
def complement_schur(M_current, b, d, id_pop):
- r""" This function computes the inverse of matrix in regularization path
- using Schur complement
+ r""" This function computes the inverse of the design matrix in the \
+ regularization path using the Schur complement. Two cases may arise:
+
+ Case 1: one variable is added to the active set
+
- Two cases may arise: Firstly one variable is added to the active set
.. math::
M_{k+1}^{-1} =
\begin{bmatrix}
- M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\
- - s^{-1} b^T M_{k}^{-1} & s^{-1}
+ M_{k}^{-1} + s^{-1} M_{k}^{-1} \mathbf{b} \mathbf{b}^T M_{k}^{-1} \
+ & - M_{k}^{-1} \mathbf{b} s^{-1} \\
+ - s^{-1} \mathbf{b}^T M_{k}^{-1} & s^{-1}
\end{bmatrix}
+
+
where :
- - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and
- :math:`M_k` is the upper left block matrix in Schur formulation
- - b is the upper right block matrix in Schur formulation. In our case,
- b is reduced to a column vector and b^T is the lower left block matrix
- - s is the Schur complement, given by
- :math:`s = d - b^T M_{k}^{-1} b` in our case
-
- Secondly, one variable is removed from the active set
+
+ - :math:`M_k^{-1}` is the inverse of the design matrix :math:`H_A^tH_A` \
+ of the previous iteration
+ - :math:`\mathbf{b}` is the last column of :math:`M_{k}`
+ - :math:`s` is the Schur complement, given by \
+ :math:`s = \mathbf{d} - \mathbf{b}^T M_{k}^{-1} \mathbf{b}`
+
+ Case 2: one variable is removed from the active set.
+
.. math::
- M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} -
+ M_{k+1}^{-1} = M^{-1}_{k \backslash q} -
\frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}}
+
where :
- - q is the index of column and row to delete
- - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix
- without qth column and qth row
- - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element
- - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}`
+
+ - :math:`q` is the index of column and row to delete
+ - :math:`M^{-1}_{k \backslash q}` is the previous inverse matrix deprived \
+ of the :math:`q` th column and :math:`q` th row
+ - :math:`r_{-q,q}` is the :math:`q` th column of :math:`M^{-1}_{k}` \
+ without the :math:`q` th element
+ - :math:`r_{q, q}` is the element of :math:`q` th column and :math:`q` th \
+ row in :math:`M^{-1}_{k}`
+
+
Parameters
----------
- M_current : np.ndarray (|A|-1, |A|-1)
- Inverse matrix in previous iteration
- b : np.ndarray (|A|-1, )
- Upper right matrix in Schur complement, a column vector in our case
+ M_current : ndarray, shape (size(A)-1, size(A)-1)
+ Inverse matrix of :math:`H_A^tH_A` at the previous iteration, with \
+ size(A) the size of the active set
+ b : ndarray, shape (size(A)-1, )
+ None for case 2 (removal), last column of :math:`M_{k}` for case 1 \
+ (addition)
d : float
- Lower right matrix in Schur complement, a scalar in our case
- id_pop
+ should be equal to 2 when UOT and 1 for the semi-relaxed OT
+ id_pop : int
Index of the variable to be removed, equal to -1
- if none of the variables is deleted in current iteration
+ if no variable is deleted at the current iteration
+
+
Returns
-------
- M : np.ndarray (|A|, |A|)
- Inverse matrix needed in current iteration
+ M : ndarray, shape (size(A), size(A))
+ Inverse matrix of :math:`H_A^tH_A` of the current iteration
+
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
+
if b is None:
b = M_current[id_pop, :]
b = np.delete(b, id_pop)
@@ -409,33 +490,39 @@ def complement_schur(M_current, b, d, id_pop):
def construct_augmented_H(active_index, m, Hc, HrHr):
- r""" This function construct an augmented matrix for the first iteration of
- semi-relaxed regularization path
+ r""" This function constructs an augmented matrix for the first iteration
+ of the semi-relaxed regularization path
.. math::
- Augmented_H =
+ \text{Augmented}_H =
\begin{bmatrix}
0 & H_{c A} \\
H_{c A}^T & H_{r A}^T H_{r A}
\end{bmatrix}
+
where :
- - H_{r A} is the sub-matrix constructed by the columns of H_r
- whose indices belong to the active set A
- - H_{c A} is the sub-matrix constructed by the columns of H_c
- whose indices belong to the active set A
+
+ - :math:`H_{r A}` is the sub-matrix constructed by the columns of \
+ :math:`H_r` whose indices belong to the active set A
+ - :math:`H_{c A}` is the sub-matrix constructed by the columns of \
+ :math:`H_c` whose indices belong to the active set A
+
+
Parameters
----------
active_index : list
- Indices of active variables
+ Indices of the active variables
m : int
Length of the target distribution
Hc : np.ndarray (dim_b, dim_a * dim_b)
- Matrix that computes the sum along the columns of transport plan T
+ Matrix that computes the sum along the columns of the transport plan \
+ :math:`T`
HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
- Matrix product of H_r^T H_r
+ Matrix product of :math:`H_r^T H_r`
+
Returns
-------
- H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|)
+ H_augmented : np.ndarray (dim_b + size(A), dim_b + size(A))
Augmented matrix for the first iteration of the semi-relaxed
regularization path
"""
@@ -451,18 +538,27 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
r"""This function gives the regularization path of l2-penalized UOT problem
The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
.. math::
- \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2
+
s.t.
- t \geq 0
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`{C}`
- :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
- - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
- see [Chapel et al., 2021] for the design of H. The matrix product Ht
- computes both the source marginal and the target marginal.
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`, defined as \
+ :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]`
+ - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \
+ for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -478,11 +574,12 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of solutions in the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of regularization coefficients in the regularization path
+
Examples
--------
>>> import ot
@@ -502,10 +599,9 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
n = np.shape(a)[0]
@@ -580,22 +676,32 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
itmax=50000):
r"""This function gives the regularization path of semi-relaxed
- l2-UOT problem
+ l2-UOT problem.
The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
.. math::
- \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+
+ \min_t \gamma \mathbf{c}^T t
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
s.t.
- H_c t = b
- t \geq 0
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
where :
- - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
- - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
- - H_r is a (dim_a, dim_a * dim_b) metric matrix,
- which computes the sum along the rows of transport plan T
- - H_c is a (dim_b, dim_a * dim_b) metric matrix,
- which computes the sum along the columns of transport plan T
- - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`C`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization parameter
+ - :math:`H_r` is a matrix that computes the sum along the rows of \
+ the transport plan :math:`T`
+ - :math:`H_c` is a matrix that computes the sum along the columns of \
+ the transport plan :math:`T`
+ - :math:`\mathbf{t}` is the flattened version of the transport plan \
+ :math:`T`
+
Parameters
----------
a : np.ndarray (dim_a,)
@@ -608,14 +714,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
l2-regularization coefficient
itmax: int (optional)
Maximum number of iteration
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the (unregularized) optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of all the optimal transport vectors of the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of the regularization parameters in the path
+
Examples
--------
>>> import ot
@@ -635,10 +743,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
n = np.shape(a)[0]
@@ -722,8 +829,44 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
semi_relaxed=False, itmax=50000):
- r"""This function combines both the semi-relaxed and the fully-relaxed
- regularization paths of l2-UOT problem
+ r"""This function provides all the solutions of the regularization path \
+ of the l2-UOT problem :ref:`[41] <references-regpath>`.
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+
+ .. math::
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2
+
+ s.t.
+ \mathbf{t} \geq 0
+
+ where :
+
+ - :math:`\mathbf{c}` is the flattened version of the cost matrix \
+ :math:`{C}`
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \
+ and :math:`\mathbf{b}`, defined as \
+ :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]`
+ - :math:`{H}` is a design matrix, see :ref:`[41] <references-regpath>` \
+ for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \
+ computes both the source marginal and the target marginals.
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
+ For the semi-relaxed problem, it optimizes the Lasso reformulation of the
+ l2-penalized UOT:
+
+ .. math::
+
+ \min_t \gamma \mathbf{c}^T \mathbf{t}
+ + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2
+
+ s.t.
+ H_c \mathbf{t} = \mathbf{b}
+
+ \mathbf{t} \geq 0
+
Parameters
----------
@@ -736,23 +879,24 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
reg: float (optional)
l2-regularization coefficient
semi_relaxed : bool (optional)
- Give the semi-relaxed path if true
+ Give the semi-relaxed path if True
itmax: int (optional)
Maximum number of iteration
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Flattened vector of optimal transport matrix
+ Flattened vector of the (unregularized) optimal transport matrix
t_list : list
- List of solutions in regularization path
+ List of all the optimal transport vectors of the regularization path
gamma_list : list
- List of regularization coefficient in regularization path
+ List of the regularization parameters in the path
+
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
if semi_relaxed:
t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg,
@@ -765,27 +909,33 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
def compute_transport_plan(gamma, gamma_list, Pi_list):
r""" Given the regularization path, this function computes the transport
- plan for any value of gamma by the piecewise linearity of the path
+ plan for any value of gamma thanks to the piecewise linearity of the path.
.. math::
t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma)
- where :
- - :math:`\gamma` is the regularization coefficient
+
+ where:
+
+ - :math:`\gamma` is the regularization parameter
- :math:`\phi(\gamma)` is the corresponding intercept
- :math:`\delta(\gamma)` is the corresponding slope
- - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix)
+ - :math:`\mathbf{t}` is the flattened version of the transport matrix
+
Parameters
----------
gamma : float
Regularization coefficient
gamma_list : list
- List of regularization coefficients in regularization path
+ List of regularization parameters of the regularization path
Pi_list : list
- List of solutions in regularization path
+ List of all the solutions of the regularization path
+
Returns
-------
t : np.ndarray (dim_a*dim_b, )
- Transport vector corresponding to the given value of gamma
+ Vectorization of the transport plan corresponding to the given value
+ of gamma
+
Examples
--------
>>> import ot
@@ -804,12 +954,13 @@ def compute_transport_plan(gamma, gamma_list, Pi_list):
array([0. , 0. , 0. , 0.19722222, 0.05555556,
0. , 0. , 0.24722222, 0. ])
+
+ .. _references-regpath:
References
----------
- [Chapel et al., 2021]:
- Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
- linear regression.
+ linear regression. NeurIPS.
"""
if gamma >= gamma_list[0]:
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 503cc1e..90c920c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -4,6 +4,7 @@ Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License
from __future__ import division
@@ -1029,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
+
+
+def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ Returns
+ -------
+ gamma : (dim_a, dim_b) array-like
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2)
+ array([[0.3 , 0. ],
+ [0. , 0.07]])
+ >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2)
+ array([[0.25, 0. ],
+ [0. , 0. ]])
+
+
+ .. _references-regpath:
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
+ """
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=M) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=M) / dim_b
+
+ if G0 is None:
+ G = a[:, None] * b[None, :]
+ else:
+ G = G0
+
+ if log:
+ log = {'err': [], 'G': []}
+
+ if div == 'kl':
+ K = nx.exp(M / - reg_m / 2)
+ elif div == 'l2':
+ K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2,
+ nx.zeros((dim_a, dim_b), type_as=M))
+ else:
+ warnings.warn("The div parameter should be either equal to 'kl' or \
+ 'l2': it has been set to 'kl'.")
+ div = 'kl'
+ K = nx.exp(M / - reg_m / 2)
+
+ for i in range(numItermax):
+ Gprev = G
+
+ if div == 'kl':
+ u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16))
+ v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16))
+ G = G * K * u[:, None] * v[None, :]
+ elif div == 'l2':
+ Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16
+ G = G * K / Gd
+
+ err = nx.sqrt(nx.sum((G - Gprev) ** 2))
+ if log:
+ log['err'].append(err)
+ log['G'].append(G)
+ if verbose:
+ print('{:5d}|{:8e}|'.format(i, err))
+ if err < stopThr:
+ break
+
+ if log:
+ log['cost'] = nx.sum(G * M)
+ return G, log
+ else:
+ return G
+
+
+def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan.
+ The function solves the following optimization problem:
+
+ .. math::
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+ The algorithm used for solving the problem is a maximization-
+ minimization algorithm as proposed in :ref:`[41] <references-regpath>`
+
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg_m: float
+ Marginal relaxation term > 0
+ div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ ot_distance : array-like
+ the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression. NeurIPS.
+ See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=True)
+
+ if log:
+ return log_mm['cost'], log_mm
+ else:
+ return log_mm['cost']
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index db59504..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
"""Tests for module Unbalanced OT with entropy regularization"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#
# License: MIT License
@@ -286,3 +287,52 @@ def test_implemented_methods(nx):
method=method)
barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method=method)
+
+
+def test_mm_convergence(nx):
+ n = 100
+ rng = np.random.RandomState(42)
+ x = rng.randn(n, 2)
+ rng = np.random.RandomState(75)
+ y = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n)
+
+ M = ot.dist(x, y)
+ M = M / M.max()
+ reg_m = 100
+ a, b, M = nx.from_numpy(a, b, M)
+
+ G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+ verbose=True, log=True)
+ loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+ a, b, M, reg_m, div='kl', verbose=True))
+ G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+ verbose=False, log=True)
+
+ # check if the marginals come close to the true ones when large reg
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+ # check if mm_unbalanced2 returns the correct loss
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+ atol=1e-5)
+
+ # check in case no histogram is provided
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+ np.testing.assert_allclose(G_kl_null, G_kl)
+ np.testing.assert_allclose(G_l2_null, G_l2)
+
+ # test when G0 is given
+ G0 = ot.emd(a, b, M)
+ reg_m = 10000
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+ np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+ np.testing.assert_allclose(G0, G_l2, atol=1e-05)