diff options
Diffstat (limited to 'test/test_utils.py')
-rw-r--r-- | test/test_utils.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 31b12ef..658214d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -41,6 +41,59 @@ def test_proj_simplex(nx): np.testing.assert_allclose(l1, l2, atol=1e-5) +def test_projection_sparse_simplex(): + + def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): + r"""This is an equivalent but less efficient version + of ot.utils.projection_sparse_simplex, as it uses two + sorts instead of one. + """ + + if axis == 0: + # For each column of X, find top max_nz values and + # their corresponding indices. This incurs a sort. + max_nz_indices = np.argpartition( + X, + kth=-max_nz, + axis=0)[-max_nz:] + + max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] + + # Project the top max_nz values onto the simplex. + # This incurs a second sort. + G_nz_values = ot.smooth.projection_simplex( + max_nz_values, z=z, axis=0) + + # Put the projection of max_nz_values to their original indices + # and set all other values zero. + G = np.zeros_like(X) + G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + return G + elif axis == 1: + return double_sort_projection_sparse_simplex( + X.T, max_nz, z, axis=0).T + + else: + X = X.ravel().reshape(-1, 1) + return double_sort_projection_sparse_simplex( + X, max_nz, z, axis=0).ravel() + + m, n = 5, 10 + rng = np.random.RandomState(0) + X = rng.uniform(size=(m, n)) + max_nz = 3 + + for axis in [0, 1, None]: + slow_sparse_proj = double_sort_projection_sparse_simplex( + X, max_nz, axis=axis) + fast_sparse_proj = ot.utils.projection_sparse_simplex( + X, max_nz, axis=axis) + + # check that two versions produce consistent results + np.testing.assert_allclose( + slow_sparse_proj, fast_sparse_proj) + + def test_parmap(): n = 10 |