summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 31b12ef..658214d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -41,6 +41,59 @@ def test_proj_simplex(nx):
np.testing.assert_allclose(l1, l2, atol=1e-5)
+def test_projection_sparse_simplex():
+
+ def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None):
+ r"""This is an equivalent but less efficient version
+ of ot.utils.projection_sparse_simplex, as it uses two
+ sorts instead of one.
+ """
+
+ if axis == 0:
+ # For each column of X, find top max_nz values and
+ # their corresponding indices. This incurs a sort.
+ max_nz_indices = np.argpartition(
+ X,
+ kth=-max_nz,
+ axis=0)[-max_nz:]
+
+ max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]
+
+ # Project the top max_nz values onto the simplex.
+ # This incurs a second sort.
+ G_nz_values = ot.smooth.projection_simplex(
+ max_nz_values, z=z, axis=0)
+
+ # Put the projection of max_nz_values to their original indices
+ # and set all other values zero.
+ G = np.zeros_like(X)
+ G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
+ return G
+ elif axis == 1:
+ return double_sort_projection_sparse_simplex(
+ X.T, max_nz, z, axis=0).T
+
+ else:
+ X = X.ravel().reshape(-1, 1)
+ return double_sort_projection_sparse_simplex(
+ X, max_nz, z, axis=0).ravel()
+
+ m, n = 5, 10
+ rng = np.random.RandomState(0)
+ X = rng.uniform(size=(m, n))
+ max_nz = 3
+
+ for axis in [0, 1, None]:
+ slow_sparse_proj = double_sort_projection_sparse_simplex(
+ X, max_nz, axis=axis)
+ fast_sparse_proj = ot.utils.projection_sparse_simplex(
+ X, max_nz, axis=axis)
+
+ # check that two versions produce consistent results
+ np.testing.assert_allclose(
+ slow_sparse_proj, fast_sparse_proj)
+
+
def test_parmap():
n = 10