diff options
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 81 |
1 files changed, 80 insertions, 1 deletions
def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None):
    r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality
    constraint (maximum number of non-zero elements) and then scaled by `z`.

    .. math::
        P\left(\mathbf{V}, \text{max_nz}, z\right) =
        \mathop{\arg \min}_{\substack{\mathbf{y} \geq 0 \\
        \sum_i \mathbf{y}_i = z \\
        \|\mathbf{y}\|_0 \le \text{max_nz}}} \quad
        \|\mathbf{y} - \mathbf{V}\|^2

    Parameters
    ----------
    V: 1-dim or 2-dim ndarray
    max_nz: int
        Maximum number of non-zero elements kept in each projected vector.
        Must be at least 1; values larger than the length of the projected
        vectors are clamped to that length.
    z: float or array
        If array, len(z) must be compatible with :math:`\mathbf{V}`
    axis: None or int
        - axis=None: project :math:`\mathbf{V}` by
          :math:`P(\mathbf{V}.\mathrm{ravel}(), \text{max_nz}, z)`
        - axis=1: project each :math:`\mathbf{V}_i` by
          :math:`P(\mathbf{V}_i, \text{max_nz}, z_i)`
        - axis=0: project each :math:`\mathbf{V}_{:, j}` by
          :math:`P(\mathbf{V}_{:, j}, \text{max_nz}, z_j)`
    nx: Backend, optional
        Backend used for the computations. If None, it is inferred from
        :math:`\mathbf{V}` with `get_backend`.

    Returns
    -------
    projection: ndarray, shape :math:`\mathbf{V}`.shape

    Raises
    ------
    ValueError
        If `V` has more than two dimensions or if `max_nz` is smaller
        than 1.

    References
    ----------
    Sparse projections onto the simplex
    Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch
    ICML 2013
    https://arxiv.org/abs/1206.1529
    """
    if nx is None:
        nx = get_backend(V)

    if max_nz < 1:
        # `-max_nz:` with max_nz == 0 would silently select every column.
        raise ValueError('max_nz must be >= 1')

    if V.ndim == 1:
        # Treat a vector as a single row; forward the resolved backend so
        # it is not looked up again.
        return projection_sparse_simplex(
            V[None, :], max_nz, z, axis=1, nx=nx).ravel()

    if V.ndim > 2:
        raise ValueError('V.ndim must be <= 2')

    if axis == 0:
        # Column-wise projection is the row-wise projection of V transposed.
        return projection_sparse_simplex(
            V.T, max_nz, z, axis=1, nx=nx).T

    if axis != 1:
        # Any other value (including None): project the flattened array.
        V = V.ravel().reshape(1, -1)
        return projection_sparse_simplex(
            V, max_nz, z, axis=1, nx=nx).ravel()

    # A row cannot hold more non-zeros than it has entries; clamping keeps
    # the shapes of `U` and `ind` below consistent.
    max_nz = min(max_nz, V.shape[1])

    # For each row of V, find top max_nz values; arrange the
    # corresponding column indices such that their values are
    # in a descending order.
    max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:]
    max_nz_indices = nx.flip(max_nz_indices, axis=1)

    row_indices = nx.arange(V.shape[0])
    row_indices = row_indices.reshape(-1, 1)

    # Extract the top max_nz values for each row
    # and then project to simplex.
    U = V[row_indices, max_nz_indices]
    # NOTE(review): `nx.ones` is called without `type_as`, so `z` may not
    # share U's dtype/device on non-NumPy backends -- confirm.
    z = nx.ones(len(U)) * z
    cssv = nx.cumsum(U, axis=1) - z[:, None]
    ind = nx.arange(max_nz) + 1
    cond = U - cssv / ind > 0
    # Number of entries of each row that stay strictly positive.
    rho = nx.sum(cond, axis=1)
    theta = cssv[nx.arange(len(U)), rho - 1] / rho
    nz_projection = nx.maximum(U - theta[:, None], 0)

    # Put the projection of max_nz_values to their original column indices
    # while keeping other values zero.
    sparse_projection = nx.zeros(V.shape, type_as=nz_projection)

    if isinstance(nx, JaxBackend):
        # in Jax, we need to use the `at` property of `jax.numpy.ndarray`
        # to do in-place array modifications.
        sparse_projection = sparse_projection.at[
            row_indices, max_nz_indices].set(nz_projection)
    else:
        sparse_projection[row_indices, max_nz_indices] = nz_projection
    return sparse_projection