summaryrefslogtreecommitdiff
path: root/ot/smooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/smooth.py')
-rw-r--r--ot/smooth.py80
1 files changed, 75 insertions, 5 deletions
diff --git a/ot/smooth.py b/ot/smooth.py
index 8e0ef38..331cfc0 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -24,9 +24,10 @@
# Author: Mathieu Blondel
# Remi Flamary <remi.flamary@unice.fr>
+# Tianlin Liu <t.liu@unibas.ch>
"""
-Smooth and Sparse Optimal Transport solvers (KL an L2 reg.)
+Smooth and Sparse (KL an L2 reg.) and sparsity-constrained OT solvers.
Implementation of :
Smooth and Sparse Optimal Transport.
@@ -34,17 +35,31 @@ Mathieu Blondel, Vivien Seguy, Antoine Rolet.
In Proc. of AISTATS 2018.
https://arxiv.org/abs/1710.06276
+(Original code from https://github.com/mblondel/smooth-ot/)
+
+Sparsity-Constrained Optimal Transport.
+Liu, T., Puigcerver, J., & Blondel, M. (2023).
+Sparsity-constrained optimal transport.
+Proceedings of the Eleventh International Conference on
+Learning Representations (ICLR).
+https://arxiv.org/abs/2209.15466
+
+
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
Transport. Proceedings of the Twenty-First International Conference on
Artificial Intelligence and Statistics (AISTATS).
-Original code from https://github.com/mblondel/smooth-ot/
+[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
+Sparsity-constrained optimal transport.
+Proceedings of the Eleventh International Conference on
+Learning Representations (ICLR).
"""
import numpy as np
from scipy.optimize import minimize
from .backend import get_backend
+import ot
def projection_simplex(V, z=1, axis=None):
@@ -209,6 +224,39 @@ class SquaredL2(Regularization):
return 0.5 * self.gamma * np.sum(T ** 2)
+class SparsityConstrained(Regularization):
+ """ Squared L2 regularization with sparsity constraints """
+
+ def __init__(self, max_nz, gamma=1.0):
+ self.max_nz = max_nz
+ self.gamma = gamma
+
+ def delta_Omega(self, X):
+ # For each column of X, find entries that are not among the top max_nz.
+ non_top_indices = np.argpartition(
+ -X, self.max_nz, axis=0)[self.max_nz:]
+ # Set these entries to -inf.
+ if X.ndim == 1:
+ X[non_top_indices] = 0.0
+ else:
+ X[non_top_indices, np.arange(X.shape[1])] = 0.0
+ max_X = np.maximum(X, 0)
+ val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
+ G = max_X / self.gamma
+ return val, G
+
+ def max_Omega(self, X, b):
+ # Project the scaled X onto the simplex with sparsity constraint.
+ G = ot.utils.projection_sparse_simplex(
+ X / (b * self.gamma), self.max_nz, axis=0)
+ val = np.sum(X * G, axis=0)
+ val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
+ return val, G
+
+ def Omega(self, T):
+ return 0.5 * self.gamma * np.sum(T ** 2)
+
+
def dual_obj_grad(alpha, beta, a, b, C, regul):
r"""
Compute objective value and gradients of dual objective.
@@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
return regul.max_Omega(X, b)[1] * b
-def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
- numItermax=500, verbose=False, log=False):
+def smooth_ot_dual(a, b, M, reg, reg_type='l2',
+ method="L-BFGS-B", stopThr=1e-9,
+ numItermax=500, verbose=False, log=False, max_nz=None):
r"""
Solve the regularized OT problem in the dual and return the OT matrix
@@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
:ref:`[2] <references-smooth-ot-dual>`)
- 'l2' : Squared Euclidean regularization
+ - 'sparsity_constrained' : Sparsity-constrained regularization [50]
+ max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan;
+ not used for other regularization types.
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
@@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+ .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
See Also
--------
ot.lp.emd : Unregularized OT
@@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
+ elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
+ if not isinstance(max_nz, int):
+ raise ValueError(
+ f'max_nz {max_nz} must be an integer')
+ regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
else:
raise NotImplementedError('Unknown regularization')
@@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
return G
-def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None,
+ method="L-BFGS-B", stopThr=1e-9,
numItermax=500, verbose=False, log=False):
r"""
Solve the regularized OT problem in the semi-dual and return the OT matrix
@@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
:ref:`[2] <references-smooth-ot-semi-dual>`)
- 'l2' : Squared Euclidean regularization
+ - 'sparsity_constrained' : Sparsity-constrained regularization [50]
+ max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan;
+ not used for other regularization types.
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
@@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+ .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
See Also
--------
ot.lp.emd : Unregularized OT
@@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
+ elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
+ if not isinstance(max_nz, int):
+ raise ValueError(
+ f'max_nz {max_nz} must be an integer')
+ regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
else:
raise NotImplementedError('Unknown regularization')