diff options
author | Tianlin Liu <tliu@jacobs-alumni.de> | 2023-04-25 12:14:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-25 12:14:29 +0200 |
commit | 42a62c123776e04ee805aefb9afd6d98abdcf192 (patch) | |
tree | d439a1478c2f148c89678adc07736834b41255d4 | |
parent | 03ca4ef659a037e400975e3b2116b637a2d94265 (diff) |
[FEAT] add the sparsity-constrained optimal transport funtionality and example (#459)
* add sparsity-constrained ot funtionality and example
* correct typos; add projection_sparse_simplex
* add gradcheck; merge ot.sparse into ot.smooth.
* reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type
* address pep8 error
* add backends for
* update releases
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | RELEASES.md | 3 | ||||
-rw-r--r-- | examples/plot_OT_1D_smooth.py | 51 | ||||
-rw-r--r-- | ot/smooth.py | 80 | ||||
-rw-r--r-- | ot/utils.py | 81 | ||||
-rw-r--r-- | test/test_smooth.py | 61 | ||||
-rw-r--r-- | test/test_utils.py | 53 |
7 files changed, 291 insertions, 40 deletions
@@ -308,3 +308,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. + +[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). diff --git a/RELEASES.md b/RELEASES.md index d912215..b18fdc3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,9 +3,8 @@ ## 0.9.1dev #### New features - - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463) - +- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459) #### Closed issues - Fix circleci-redirector action and codecov (PR #460) diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 5415e4f..ff51b8a 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- """ ================================ -Smooth optimal transport example +Smooth and sparse OT example ================================ -This example illustrates the computation of EMD, Sinkhorn and smooth OT plans -and their visualization. +This example illustrates the computation of +Smooth and Sparse (KL an L2 reg.) OT and +sparsity-constrained OT, together with their visualizations. """ @@ -58,32 +59,6 @@ pl.legend() pl.figure(2, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') -############################################################################## -# Solve EMD -# --------- - - -#%% EMD - -G0 = ot.emd(a, b, M) - -pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') - -############################################################################## -# Solve Sinkhorn -# -------------- - - -#%% Sinkhorn - -lambd = 2e-3 -Gs = ot.sinkhorn(a, b, M, lambd, verbose=True) - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn') - -pl.show() ############################################################################## # Solve Smooth OT @@ -95,18 +70,30 @@ pl.show() lambd = 2e-3 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl') -pl.figure(5, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.') pl.show() -#%% Smooth OT with KL regularization +#%% Smooth OT with squared l2 regularization lambd = 1e-1 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') -pl.figure(6, figsize=(5, 5)) +pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') pl.show() + +#%% Sparsity-constrained OT + +lambd = 1e-1 + +max_nz = 2 # two non-zero entries are permitted per column of the OT plan +Gsc = ot.smooth.smooth_ot_dual( + a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz) +pl.figure(5, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.') + +pl.show() diff --git a/ot/smooth.py b/ot/smooth.py index 8e0ef38..331cfc0 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -24,9 +24,10 @@ # Author: Mathieu Blondel # Remi Flamary <remi.flamary@unice.fr> +# Tianlin Liu <t.liu@unibas.ch> """ -Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) +Smooth and Sparse (KL an L2 reg.) and sparsity-constrained OT solvers. Implementation of : Smooth and Sparse Optimal Transport. @@ -34,17 +35,31 @@ Mathieu Blondel, Vivien Seguy, Antoine Rolet. In Proc. of AISTATS 2018. https://arxiv.org/abs/1710.06276 +(Original code from https://github.com/mblondel/smooth-ot/) + +Sparsity-Constrained Optimal Transport. +Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). +https://arxiv.org/abs/2209.15466 + + [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). -Original code from https://github.com/mblondel/smooth-ot/ +[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). """ import numpy as np from scipy.optimize import minimize from .backend import get_backend +import ot def projection_simplex(V, z=1, axis=None): @@ -209,6 +224,39 @@ class SquaredL2(Regularization): return 0.5 * self.gamma * np.sum(T ** 2) +class SparsityConstrained(Regularization): + """ Squared L2 regularization with sparsity constraints """ + + def __init__(self, max_nz, gamma=1.0): + self.max_nz = max_nz + self.gamma = gamma + + def delta_Omega(self, X): + # For each column of X, find entries that are not among the top max_nz. + non_top_indices = np.argpartition( + -X, self.max_nz, axis=0)[self.max_nz:] + # Set these entries to -inf. + if X.ndim == 1: + X[non_top_indices] = 0.0 + else: + X[non_top_indices, np.arange(X.shape[1])] = 0.0 + max_X = np.maximum(X, 0) + val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + G = max_X / self.gamma + return val, G + + def max_Omega(self, X, b): + # Project the scaled X onto the simplex with sparsity constraint. + G = ot.utils.projection_sparse_simplex( + X / (b * self.gamma), self.max_nz, axis=0) + val = np.sum(X * G, axis=0) + val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) + return val, G + + def Omega(self, T): + return 0.5 * self.gamma * np.sum(T ** 2) + + def dual_obj_grad(alpha, beta, a, b, C, regul): r""" Compute objective value and gradients of dual objective. @@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul): return regul.max_Omega(X, b)[1] * b -def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): +def smooth_ot_dual(a, b, M, reg, reg_type='l2', + method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False, max_nz=None): r""" Solve the regularized OT problem in the dual and return the OT matrix @@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, :ref:`[2] <references-smooth-ot-dual>`) - 'l2' : Squared Euclidean regularization + - 'sparsity_constrained' : Sparsity-constrained regularization [50] + max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; + not used for other regularization types. method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + See Also -------- ot.lp.emd : Unregularized OT @@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + if not isinstance(max_nz, int): + raise ValueError( + f'max_nz {max_nz} must be an integer') + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: raise NotImplementedError('Unknown regularization') @@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, return G -def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, +def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, + method="L-BFGS-B", stopThr=1e-9, numItermax=500, verbose=False, log=False): r""" Solve the regularized OT problem in the semi-dual and return the OT matrix @@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= :ref:`[2] <references-smooth-ot-semi-dual>`) - 'l2' : Squared Euclidean regularization + - 'sparsity_constrained' : Sparsity-constrained regularization [50] + max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; + not used for other regularization types. method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + See Also -------- ot.lp.emd : Unregularized OT @@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + if not isinstance(max_nz, int): + raise ValueError( + f'max_nz {max_nz} must be an integer') + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: raise NotImplementedError('Unknown regularization') diff --git a/ot/utils.py b/ot/utils.py index 3423a7e..3343028 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature -from .backend import get_backend, Backend, NumpyBackend +from .backend import get_backend, Backend, NumpyBackend, JaxBackend __time_tic_toc = time.time() @@ -117,6 +117,85 @@ def proj_simplex(v, z=1): return w +def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None): + r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`. + + .. math:: + P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2 + + Parameters + ---------- + V: 1-dim or 2-dim ndarray + z: float or array + If array, len(z) must be compatible with :math:`\mathbf{V}` + axis: None or int + - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)` + - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)` + - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)` + + Returns + ------- + projection: ndarray, shape :math:`\mathbf{V}`.shape + + References: + Sparse projections onto the simplex + Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch + ICML 2013 + https://arxiv.org/abs/1206.1529 + """ + if nx is None: + nx = get_backend(V) + if V.ndim == 1: + return projection_sparse_simplex( + # V[nx.newaxis, :], max_nz, z, axis=1).ravel() + V[None, :], max_nz, z, axis=1).ravel() + + if V.ndim > 2: + raise ValueError('V.ndim must be <= 2') + + if axis == 1: + # For each row of V, find top max_nz values; arrange the + # corresponding column indices such that their values are + # in a descending order. + max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:] + max_nz_indices = nx.flip(max_nz_indices, axis=1) + + row_indices = nx.arange(V.shape[0]) + row_indices = row_indices.reshape(-1, 1) + print(row_indices.shape) + # Extract the top max_nz values for each row + # and then project to simplex. + U = V[row_indices, max_nz_indices] + z = nx.ones(len(U)) * z + cssv = nx.cumsum(U, axis=1) - z[:, None] + ind = nx.arange(max_nz) + 1 + cond = U - cssv / ind > 0 + # rho = nx.count_nonzero(cond, axis=1) + rho = nx.sum(cond, axis=1) + theta = cssv[nx.arange(len(U)), rho - 1] / rho + nz_projection = nx.maximum(U - theta[:, None], 0) + + # Put the projection of max_nz_values to their original column indices + # while keeping other values zero. + sparse_projection = nx.zeros(V.shape, type_as=nz_projection) + + if isinstance(nx, JaxBackend): + # in Jax, we need to use the `at` property of `jax.numpy.ndarray` + # to do in-place array modificatons. + sparse_projection = sparse_projection.at[ + row_indices, max_nz_indices].set(nz_projection) + else: + sparse_projection[row_indices, max_nz_indices] = nz_projection + return sparse_projection + + elif axis == 0: + return projection_sparse_simplex(V.T, max_nz, z, axis=1).T + + else: + V = V.ravel().reshape(1, -1) + return projection_sparse_simplex(V, max_nz, z, axis=1).ravel() + + def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). diff --git a/test/test_smooth.py b/test/test_smooth.py index 31e0b2e..dbdd405 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -7,6 +7,7 @@ import numpy as np import ot import pytest +from scipy.optimize import check_grad def test_smooth_ot_dual(): @@ -23,6 +24,7 @@ def test_smooth_ot_dual(): with pytest.raises(NotImplementedError): Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none') + # squared l2 regularisation Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) # check constraints @@ -43,6 +45,24 @@ def test_smooth_ot_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) + # sparsity-constrained regularisation + max_nz = 2 + Gsc, log = ot.smooth.smooth_ot_dual( + u, u, M, 1, + max_nz=max_nz, + log=True, + reg_type='sparsity_constrained', + stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) + np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) + + # check sparsity constraints + np.testing.assert_array_less( + np.sum(Gsc > 0, axis=0), + np.ones(n) * max_nz + 1) + def test_smooth_ot_semi_dual(): @@ -58,6 +78,7 @@ def test_smooth_ot_semi_dual(): with pytest.raises(NotImplementedError): Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none') + # squared l2 regularisation Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) # check constraints @@ -77,3 +98,43 @@ def test_smooth_ot_semi_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) + + # sparsity-constrained regularisation + max_nz = 2 + Gsc = ot.smooth.smooth_ot_semi_dual( + u, u, M, 1, reg_type='sparsity_constrained', + max_nz=max_nz, stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) + np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) + + # check sparsity constraints + np.testing.assert_array_less(np.sum(Gsc > 0, axis=0), + np.ones(n) * max_nz + 1) + + +def test_sparsity_constrained_gradient(): + max_nz = 5 + regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz) + rng = np.random.RandomState(0) + X = rng.randn(10,) + b = 0.5 + + def delta_omega_func(X): + return regularizer.delta_Omega(X)[0] + + def delta_omega_grad(X): + return regularizer.delta_Omega(X)[1] + + dual_grad_err = check_grad(delta_omega_func, delta_omega_grad, X) + np.testing.assert_allclose(dual_grad_err, 0.0, atol=1e-07) + + def max_omega_func(X, b): + return regularizer.max_Omega(X, b)[0] + + def max_omega_grad(X, b): + return regularizer.max_Omega(X, b)[1] + + semi_dual_grad_err = check_grad(max_omega_func, max_omega_grad, X, b) + np.testing.assert_allclose(semi_dual_grad_err, 0.0, atol=1e-07) diff --git a/test/test_utils.py b/test/test_utils.py index 31b12ef..658214d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -41,6 +41,59 @@ def test_proj_simplex(nx): np.testing.assert_allclose(l1, l2, atol=1e-5) +def test_projection_sparse_simplex(): + + def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): + r"""This is an equivalent but less efficient version + of ot.utils.projection_sparse_simplex, as it uses two + sorts instead of one. + """ + + if axis == 0: + # For each column of X, find top max_nz values and + # their corresponding indices. This incurs a sort. + max_nz_indices = np.argpartition( + X, + kth=-max_nz, + axis=0)[-max_nz:] + + max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] + + # Project the top max_nz values onto the simplex. + # This incurs a second sort. + G_nz_values = ot.smooth.projection_simplex( + max_nz_values, z=z, axis=0) + + # Put the projection of max_nz_values to their original indices + # and set all other values zero. + G = np.zeros_like(X) + G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + return G + elif axis == 1: + return double_sort_projection_sparse_simplex( + X.T, max_nz, z, axis=0).T + + else: + X = X.ravel().reshape(-1, 1) + return double_sort_projection_sparse_simplex( + X, max_nz, z, axis=0).ravel() + + m, n = 5, 10 + rng = np.random.RandomState(0) + X = rng.uniform(size=(m, n)) + max_nz = 3 + + for axis in [0, 1, None]: + slow_sparse_proj = double_sort_projection_sparse_simplex( + X, max_nz, axis=axis) + fast_sparse_proj = ot.utils.projection_sparse_simplex( + X, max_nz, axis=axis) + + # check that two versions produce consistent results + np.testing.assert_allclose( + slow_sparse_proj, fast_sparse_proj) + + def test_parmap(): n = 10 |