summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py81
1 files changed, 80 insertions, 1 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 3423a7e..3343028 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend, Backend, NumpyBackend
+from .backend import get_backend, Backend, NumpyBackend, JaxBackend
__time_tic_toc = time.time()
@@ -117,6 +117,85 @@ def proj_simplex(v, z=1):
return w
+def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None):
+ r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`.
+
+ .. math::
+ P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2
+
+ Parameters
+ ----------
+ V: 1-dim or 2-dim ndarray
+ z: float or array
+ If array, len(z) must be compatible with :math:`\mathbf{V}`
+ axis: None or int
+ - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)`
+ - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)`
+ - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)`
+
+ Returns
+ -------
+ projection: ndarray, shape :math:`\mathbf{V}`.shape
+
+ References:
+ Sparse projections onto the simplex
+ Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch
+ ICML 2013
+ https://arxiv.org/abs/1206.1529
+ """
+ if nx is None:
+ nx = get_backend(V)
+ if V.ndim == 1:
+ return projection_sparse_simplex(
+ # V[nx.newaxis, :], max_nz, z, axis=1).ravel()
+ V[None, :], max_nz, z, axis=1).ravel()
+
+ if V.ndim > 2:
+ raise ValueError('V.ndim must be <= 2')
+
+ if axis == 1:
+ # For each row of V, find top max_nz values; arrange the
+ # corresponding column indices such that their values are
+ # in a descending order.
+ max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:]
+ max_nz_indices = nx.flip(max_nz_indices, axis=1)
+
+ row_indices = nx.arange(V.shape[0])
+ row_indices = row_indices.reshape(-1, 1)
+ print(row_indices.shape)
+ # Extract the top max_nz values for each row
+ # and then project to simplex.
+ U = V[row_indices, max_nz_indices]
+ z = nx.ones(len(U)) * z
+ cssv = nx.cumsum(U, axis=1) - z[:, None]
+ ind = nx.arange(max_nz) + 1
+ cond = U - cssv / ind > 0
+ # rho = nx.count_nonzero(cond, axis=1)
+ rho = nx.sum(cond, axis=1)
+ theta = cssv[nx.arange(len(U)), rho - 1] / rho
+ nz_projection = nx.maximum(U - theta[:, None], 0)
+
+ # Put the projection of max_nz_values to their original column indices
+ # while keeping other values zero.
+ sparse_projection = nx.zeros(V.shape, type_as=nz_projection)
+
+ if isinstance(nx, JaxBackend):
+ # in Jax, we need to use the `at` property of `jax.numpy.ndarray`
+ # to do in-place array modificatons.
+ sparse_projection = sparse_projection.at[
+ row_indices, max_nz_indices].set(nz_projection)
+ else:
+ sparse_projection[row_indices, max_nz_indices] = nz_projection
+ return sparse_projection
+
+ elif axis == 0:
+ return projection_sparse_simplex(V.T, max_nz, z, axis=1).T
+
+ else:
+ V = V.ravel().reshape(1, -1)
+ return projection_sparse_simplex(V, max_nz, z, axis=1).ravel()
+
+
def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).