summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTianlin Liu <tliu@jacobs-alumni.de>2023-04-25 12:14:29 +0200
committerGitHub <noreply@github.com>2023-04-25 12:14:29 +0200
commit42a62c123776e04ee805aefb9afd6d98abdcf192 (patch)
treed439a1478c2f148c89678adc07736834b41255d4
parent03ca4ef659a037e400975e3b2116b637a2d94265 (diff)
[FEAT] add the sparsity-constrained optimal transport funtionality and example (#459)
* add sparsity-constrained ot funtionality and example * correct typos; add projection_sparse_simplex * add gradcheck; merge ot.sparse into ot.smooth. * reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type * address pep8 error * add backends for * update releases --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--README.md2
-rw-r--r--RELEASES.md3
-rw-r--r--examples/plot_OT_1D_smooth.py51
-rw-r--r--ot/smooth.py80
-rw-r--r--ot/utils.py81
-rw-r--r--test/test_smooth.py61
-rw-r--r--test/test_utils.py53
7 files changed, 291 insertions, 40 deletions
diff --git a/README.md b/README.md
index 2a81e95..f0fb4bd 100644
--- a/README.md
+++ b/README.md
@@ -308,3 +308,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
+
+[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
diff --git a/RELEASES.md b/RELEASES.md
index d912215..b18fdc3 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -3,9 +3,8 @@
## 0.9.1dev
#### New features
-
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
-
+- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
#### Closed issues
- Fix circleci-redirector action and codecov (PR #460)
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index 5415e4f..ff51b8a 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
"""
================================
-Smooth optimal transport example
+Smooth and sparse OT example
================================
-This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
-and their visualization.
+This example illustrates the computation of
+Smooth and Sparse (KL an L2 reg.) OT and
+sparsity-constrained OT, together with their visualizations.
"""
@@ -58,32 +59,6 @@ pl.legend()
pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
-##############################################################################
-# Solve EMD
-# ---------
-
-
-#%% EMD
-
-G0 = ot.emd(a, b, M)
-
-pl.figure(3, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
-
-##############################################################################
-# Solve Sinkhorn
-# --------------
-
-
-#%% Sinkhorn
-
-lambd = 2e-3
-Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
-
-pl.figure(4, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
-
-pl.show()
##############################################################################
# Solve Smooth OT
@@ -95,18 +70,30 @@ pl.show()
lambd = 2e-3
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
-pl.figure(5, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
pl.show()
-#%% Smooth OT with KL regularization
+#%% Smooth OT with squared l2 regularization
lambd = 1e-1
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
-pl.figure(6, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
pl.show()
+
+#%% Sparsity-constrained OT
+
+lambd = 1e-1
+
+max_nz = 2 # two non-zero entries are permitted per column of the OT plan
+Gsc = ot.smooth.smooth_ot_dual(
+ a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz)
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.')
+
+pl.show()
diff --git a/ot/smooth.py b/ot/smooth.py
index 8e0ef38..331cfc0 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -24,9 +24,10 @@
# Author: Mathieu Blondel
# Remi Flamary <remi.flamary@unice.fr>
+# Tianlin Liu <t.liu@unibas.ch>
"""
-Smooth and Sparse Optimal Transport solvers (KL an L2 reg.)
+Smooth and Sparse (KL an L2 reg.) and sparsity-constrained OT solvers.
Implementation of :
Smooth and Sparse Optimal Transport.
@@ -34,17 +35,31 @@ Mathieu Blondel, Vivien Seguy, Antoine Rolet.
In Proc. of AISTATS 2018.
https://arxiv.org/abs/1710.06276
+(Original code from https://github.com/mblondel/smooth-ot/)
+
+Sparsity-Constrained Optimal Transport.
+Liu, T., Puigcerver, J., & Blondel, M. (2023).
+Sparsity-constrained optimal transport.
+Proceedings of the Eleventh International Conference on
+Learning Representations (ICLR).
+https://arxiv.org/abs/2209.15466
+
+
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
Transport. Proceedings of the Twenty-First International Conference on
Artificial Intelligence and Statistics (AISTATS).
-Original code from https://github.com/mblondel/smooth-ot/
+[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
+Sparsity-constrained optimal transport.
+Proceedings of the Eleventh International Conference on
+Learning Representations (ICLR).
"""
import numpy as np
from scipy.optimize import minimize
from .backend import get_backend
+import ot
def projection_simplex(V, z=1, axis=None):
@@ -209,6 +224,39 @@ class SquaredL2(Regularization):
return 0.5 * self.gamma * np.sum(T ** 2)
+class SparsityConstrained(Regularization):
+ """ Squared L2 regularization with sparsity constraints """
+
+ def __init__(self, max_nz, gamma=1.0):
+ self.max_nz = max_nz
+ self.gamma = gamma
+
+ def delta_Omega(self, X):
+ # For each column of X, find entries that are not among the top max_nz.
+ non_top_indices = np.argpartition(
+ -X, self.max_nz, axis=0)[self.max_nz:]
+ # Set these entries to -inf.
+ if X.ndim == 1:
+ X[non_top_indices] = 0.0
+ else:
+ X[non_top_indices, np.arange(X.shape[1])] = 0.0
+ max_X = np.maximum(X, 0)
+ val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
+ G = max_X / self.gamma
+ return val, G
+
+ def max_Omega(self, X, b):
+ # Project the scaled X onto the simplex with sparsity constraint.
+ G = ot.utils.projection_sparse_simplex(
+ X / (b * self.gamma), self.max_nz, axis=0)
+ val = np.sum(X * G, axis=0)
+ val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return 0.5 * self.gamma * np.sum(T ** 2)
+
+
def dual_obj_grad(alpha, beta, a, b, C, regul):
r"""
Compute objective value and gradients of dual objective.
@@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
return regul.max_Omega(X, b)[1] * b
-def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
- numItermax=500, verbose=False, log=False):
+def smooth_ot_dual(a, b, M, reg, reg_type='l2',
+ method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False, max_nz=None):
r"""
Solve the regularized OT problem in the dual and return the OT matrix
@@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
:ref:`[2] <references-smooth-ot-dual>`)
- 'l2' : Squared Euclidean regularization
+ - 'sparsity_constrained' : Sparsity-constrained regularization [50]
+ max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan;
+ not used for other regularization types.
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
@@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+ .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
See Also
--------
ot.lp.emd : Unregularized OT
@@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
+ elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
+ if not isinstance(max_nz, int):
+ raise ValueError(
+ f'max_nz {max_nz} must be an integer')
+ regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
else:
raise NotImplementedError('Unknown regularization')
@@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
return G
-def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None,
+ method="L-BFGS-B", stopThr=1e-9,
numItermax=500, verbose=False, log=False):
r"""
Solve the regularized OT problem in the semi-dual and return the OT matrix
@@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
:ref:`[2] <references-smooth-ot-semi-dual>`)
- 'l2' : Squared Euclidean regularization
+ - 'sparsity_constrained' : Sparsity-constrained regularization [50]
+ max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan;
+ not used for other regularization types.
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
@@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+ .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
See Also
--------
ot.lp.emd : Unregularized OT
@@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
+ elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
+ if not isinstance(max_nz, int):
+ raise ValueError(
+ f'max_nz {max_nz} must be an integer')
+ regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
else:
raise NotImplementedError('Unknown regularization')
diff --git a/ot/utils.py b/ot/utils.py
index 3423a7e..3343028 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend, Backend, NumpyBackend
+from .backend import get_backend, Backend, NumpyBackend, JaxBackend
__time_tic_toc = time.time()
@@ -117,6 +117,85 @@ def proj_simplex(v, z=1):
return w
+def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None):
+ r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`.
+
+ .. math::
+ P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2
+
+ Parameters
+ ----------
+ V: 1-dim or 2-dim ndarray
+ z: float or array
+ If array, len(z) must be compatible with :math:`\mathbf{V}`
+ axis: None or int
+ - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)`
+ - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)`
+ - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)`
+
+ Returns
+ -------
+ projection: ndarray, shape :math:`\mathbf{V}`.shape
+
+ References:
+ Sparse projections onto the simplex
+ Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch
+ ICML 2013
+ https://arxiv.org/abs/1206.1529
+ """
+ if nx is None:
+ nx = get_backend(V)
+ if V.ndim == 1:
+ return projection_sparse_simplex(
+ # V[nx.newaxis, :], max_nz, z, axis=1).ravel()
+ V[None, :], max_nz, z, axis=1).ravel()
+
+ if V.ndim > 2:
+ raise ValueError('V.ndim must be <= 2')
+
+ if axis == 1:
+ # For each row of V, find top max_nz values; arrange the
+ # corresponding column indices such that their values are
+ # in a descending order.
+ max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:]
+ max_nz_indices = nx.flip(max_nz_indices, axis=1)
+
+ row_indices = nx.arange(V.shape[0])
+ row_indices = row_indices.reshape(-1, 1)
+ print(row_indices.shape)
+ # Extract the top max_nz values for each row
+ # and then project to simplex.
+ U = V[row_indices, max_nz_indices]
+ z = nx.ones(len(U)) * z
+ cssv = nx.cumsum(U, axis=1) - z[:, None]
+ ind = nx.arange(max_nz) + 1
+ cond = U - cssv / ind > 0
+ # rho = nx.count_nonzero(cond, axis=1)
+ rho = nx.sum(cond, axis=1)
+ theta = cssv[nx.arange(len(U)), rho - 1] / rho
+ nz_projection = nx.maximum(U - theta[:, None], 0)
+
+ # Put the projection of max_nz_values to their original column indices
+ # while keeping other values zero.
+ sparse_projection = nx.zeros(V.shape, type_as=nz_projection)
+
+ if isinstance(nx, JaxBackend):
+ # in Jax, we need to use the `at` property of `jax.numpy.ndarray`
+ # to do in-place array modificatons.
+ sparse_projection = sparse_projection.at[
+ row_indices, max_nz_indices].set(nz_projection)
+ else:
+ sparse_projection[row_indices, max_nz_indices] = nz_projection
+ return sparse_projection
+
+ elif axis == 0:
+ return projection_sparse_simplex(V.T, max_nz, z, axis=1).T
+
+ else:
+ V = V.ravel().reshape(1, -1)
+ return projection_sparse_simplex(V, max_nz, z, axis=1).ravel()
+
+
def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 31e0b2e..dbdd405 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -7,6 +7,7 @@
import numpy as np
import ot
import pytest
+from scipy.optimize import check_grad
def test_smooth_ot_dual():
@@ -23,6 +24,7 @@ def test_smooth_ot_dual():
with pytest.raises(NotImplementedError):
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none')
+ # squared l2 regularisation
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
# check constraints
@@ -43,6 +45,24 @@ def test_smooth_ot_dual():
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
np.testing.assert_allclose(G, G2, atol=1e-05)
+ # sparsity-constrained regularisation
+ max_nz = 2
+ Gsc, log = ot.smooth.smooth_ot_dual(
+ u, u, M, 1,
+ max_nz=max_nz,
+ log=True,
+ reg_type='sparsity_constrained',
+ stopThr=1e-10)
+
+ # check marginal constraints
+ np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03)
+ np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03)
+
+ # check sparsity constraints
+ np.testing.assert_array_less(
+ np.sum(Gsc > 0, axis=0),
+ np.ones(n) * max_nz + 1)
+
def test_smooth_ot_semi_dual():
@@ -58,6 +78,7 @@ def test_smooth_ot_semi_dual():
with pytest.raises(NotImplementedError):
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none')
+ # squared l2 regularisation
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
# check constraints
@@ -77,3 +98,43 @@ def test_smooth_ot_semi_dual():
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
np.testing.assert_allclose(G, G2, atol=1e-05)
+
+ # sparsity-constrained regularisation
+ max_nz = 2
+ Gsc = ot.smooth.smooth_ot_semi_dual(
+ u, u, M, 1, reg_type='sparsity_constrained',
+ max_nz=max_nz, stopThr=1e-10)
+
+ # check marginal constraints
+ np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03)
+ np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03)
+
+ # check sparsity constraints
+ np.testing.assert_array_less(np.sum(Gsc > 0, axis=0),
+ np.ones(n) * max_nz + 1)
+
+
+def test_sparsity_constrained_gradient():
+ max_nz = 5
+ regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz)
+ rng = np.random.RandomState(0)
+ X = rng.randn(10,)
+ b = 0.5
+
+ def delta_omega_func(X):
+ return regularizer.delta_Omega(X)[0]
+
+ def delta_omega_grad(X):
+ return regularizer.delta_Omega(X)[1]
+
+ dual_grad_err = check_grad(delta_omega_func, delta_omega_grad, X)
+ np.testing.assert_allclose(dual_grad_err, 0.0, atol=1e-07)
+
+ def max_omega_func(X, b):
+ return regularizer.max_Omega(X, b)[0]
+
+ def max_omega_grad(X, b):
+ return regularizer.max_Omega(X, b)[1]
+
+ semi_dual_grad_err = check_grad(max_omega_func, max_omega_grad, X, b)
+ np.testing.assert_allclose(semi_dual_grad_err, 0.0, atol=1e-07)
diff --git a/test/test_utils.py b/test/test_utils.py
index 31b12ef..658214d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -41,6 +41,59 @@ def test_proj_simplex(nx):
np.testing.assert_allclose(l1, l2, atol=1e-5)
+def test_projection_sparse_simplex():
+
+ def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None):
+ r"""This is an equivalent but less efficient version
+ of ot.utils.projection_sparse_simplex, as it uses two
+ sorts instead of one.
+ """
+
+ if axis == 0:
+ # For each column of X, find top max_nz values and
+ # their corresponding indices. This incurs a sort.
+ max_nz_indices = np.argpartition(
+ X,
+ kth=-max_nz,
+ axis=0)[-max_nz:]
+
+ max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]
+
+ # Project the top max_nz values onto the simplex.
+ # This incurs a second sort.
+ G_nz_values = ot.smooth.projection_simplex(
+ max_nz_values, z=z, axis=0)
+
+ # Put the projection of max_nz_values to their original indices
+ # and set all other values zero.
+ G = np.zeros_like(X)
+ G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
+ return G
+ elif axis == 1:
+ return double_sort_projection_sparse_simplex(
+ X.T, max_nz, z, axis=0).T
+
+ else:
+ X = X.ravel().reshape(-1, 1)
+ return double_sort_projection_sparse_simplex(
+ X, max_nz, z, axis=0).ravel()
+
+ m, n = 5, 10
+ rng = np.random.RandomState(0)
+ X = rng.uniform(size=(m, n))
+ max_nz = 3
+
+ for axis in [0, 1, None]:
+ slow_sparse_proj = double_sort_projection_sparse_simplex(
+ X, max_nz, axis=axis)
+ fast_sparse_proj = ot.utils.projection_sparse_simplex(
+ X, max_nz, axis=axis)
+
+ # check that two versions produce consistent results
+ np.testing.assert_allclose(
+ slow_sparse_proj, fast_sparse_proj)
+
+
def test_parmap():
n = 10