diff options
author | Tianlin Liu <tliu@jacobs-alumni.de> | 2023-04-25 12:14:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-25 12:14:29 +0200 |
commit | 42a62c123776e04ee805aefb9afd6d98abdcf192 (patch) | |
tree | d439a1478c2f148c89678adc07736834b41255d4 /ot/smooth.py | |
parent | 03ca4ef659a037e400975e3b2116b637a2d94265 (diff) |
[FEAT] add the sparsity-constrained optimal transport funtionality and example (#459)
* add sparsity-constrained ot funtionality and example
* correct typos; add projection_sparse_simplex
* add gradcheck; merge ot.sparse into ot.smooth.
* reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type
* address pep8 error
* add backends for
* update releases
---------
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/smooth.py')
-rw-r--r-- | ot/smooth.py | 80 |
1 files changed, 75 insertions, 5 deletions
diff --git a/ot/smooth.py b/ot/smooth.py index 8e0ef38..331cfc0 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -24,9 +24,10 @@ # Author: Mathieu Blondel # Remi Flamary <remi.flamary@unice.fr> +# Tianlin Liu <t.liu@unibas.ch> """ -Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) +Smooth and Sparse (KL an L2 reg.) and sparsity-constrained OT solvers. Implementation of : Smooth and Sparse Optimal Transport. @@ -34,17 +35,31 @@ Mathieu Blondel, Vivien Seguy, Antoine Rolet. In Proc. of AISTATS 2018. https://arxiv.org/abs/1710.06276 +(Original code from https://github.com/mblondel/smooth-ot/) + +Sparsity-Constrained Optimal Transport. +Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). +https://arxiv.org/abs/2209.15466 + + [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). -Original code from https://github.com/mblondel/smooth-ot/ +[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). """ import numpy as np from scipy.optimize import minimize from .backend import get_backend +import ot def projection_simplex(V, z=1, axis=None): @@ -209,6 +224,39 @@ class SquaredL2(Regularization): return 0.5 * self.gamma * np.sum(T ** 2) +class SparsityConstrained(Regularization): + """ Squared L2 regularization with sparsity constraints """ + + def __init__(self, max_nz, gamma=1.0): + self.max_nz = max_nz + self.gamma = gamma + + def delta_Omega(self, X): + # For each column of X, find entries that are not among the top max_nz. + non_top_indices = np.argpartition( + -X, self.max_nz, axis=0)[self.max_nz:] + # Set these entries to -inf. + if X.ndim == 1: + X[non_top_indices] = 0.0 + else: + X[non_top_indices, np.arange(X.shape[1])] = 0.0 + max_X = np.maximum(X, 0) + val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + G = max_X / self.gamma + return val, G + + def max_Omega(self, X, b): + # Project the scaled X onto the simplex with sparsity constraint. + G = ot.utils.projection_sparse_simplex( + X / (b * self.gamma), self.max_nz, axis=0) + val = np.sum(X * G, axis=0) + val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) + return val, G + + def Omega(self, T): + return 0.5 * self.gamma * np.sum(T ** 2) + + def dual_obj_grad(alpha, beta, a, b, C, regul): r""" Compute objective value and gradients of dual objective. @@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul): return regul.max_Omega(X, b)[1] * b -def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): +def smooth_ot_dual(a, b, M, reg, reg_type='l2', + method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False, max_nz=None): r""" Solve the regularized OT problem in the dual and return the OT matrix @@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, :ref:`[2] <references-smooth-ot-dual>`) - 'l2' : Squared Euclidean regularization + - 'sparsity_constrained' : Sparsity-constrained regularization [50] + max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; + not used for other regularization types. method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + See Also -------- ot.lp.emd : Unregularized OT @@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + if not isinstance(max_nz, int): + raise ValueError( + f'max_nz {max_nz} must be an integer') + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: raise NotImplementedError('Unknown regularization') @@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, return G -def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, +def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, + method="L-BFGS-B", stopThr=1e-9, numItermax=500, verbose=False, log=False): r""" Solve the regularized OT problem in the semi-dual and return the OT matrix @@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= :ref:`[2] <references-smooth-ot-semi-dual>`) - 'l2' : Squared Euclidean regularization + - 'sparsity_constrained' : Sparsity-constrained regularization [50] + max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; + not used for other regularization types. method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + See Also -------- ot.lp.emd : Unregularized OT @@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + if not isinstance(max_nz, int): + raise ValueError( + f'max_nz {max_nz} must be an integer') + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: raise NotImplementedError('Unknown regularization') |