summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorTianlin Liu <tliu@jacobs-alumni.de>2023-04-25 12:14:29 +0200
committerGitHub <noreply@github.com>2023-04-25 12:14:29 +0200
commit42a62c123776e04ee805aefb9afd6d98abdcf192 (patch)
treed439a1478c2f148c89678adc07736834b41255d4 /test
parent03ca4ef659a037e400975e3b2116b637a2d94265 (diff)
[FEAT] add the sparsity-constrained optimal transport functionality and example (#459)
* add sparsity-constrained ot functionality and example * correct typos; add projection_sparse_simplex * add gradcheck; merge ot.sparse into ot.smooth. * reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type * address pep8 error * add backends for * update releases --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_smooth.py61
-rw-r--r--test/test_utils.py53
2 files changed, 114 insertions, 0 deletions
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 31e0b2e..dbdd405 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -7,6 +7,7 @@
import numpy as np
import ot
import pytest
+from scipy.optimize import check_grad
def test_smooth_ot_dual():
@@ -23,6 +24,7 @@ def test_smooth_ot_dual():
with pytest.raises(NotImplementedError):
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none')
+ # squared l2 regularisation
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
# check constraints
@@ -43,6 +45,24 @@ def test_smooth_ot_dual():
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
np.testing.assert_allclose(G, G2, atol=1e-05)
+ # sparsity-constrained regularisation
+ max_nz = 2
+ Gsc, log = ot.smooth.smooth_ot_dual(
+ u, u, M, 1,
+ max_nz=max_nz,
+ log=True,
+ reg_type='sparsity_constrained',
+ stopThr=1e-10)
+
+ # check marginal constraints
+ np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03)
+ np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03)
+
+ # check sparsity constraints
+ np.testing.assert_array_less(
+ np.sum(Gsc > 0, axis=0),
+ np.ones(n) * max_nz + 1)
+
def test_smooth_ot_semi_dual():
@@ -58,6 +78,7 @@ def test_smooth_ot_semi_dual():
with pytest.raises(NotImplementedError):
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none')
+ # squared l2 regularisation
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
# check constraints
@@ -77,3 +98,43 @@ def test_smooth_ot_semi_dual():
G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
np.testing.assert_allclose(G, G2, atol=1e-05)
+
+ # sparsity-constrained regularisation
+ max_nz = 2
+ Gsc = ot.smooth.smooth_ot_semi_dual(
+ u, u, M, 1, reg_type='sparsity_constrained',
+ max_nz=max_nz, stopThr=1e-10)
+
+ # check marginal constraints
+ np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03)
+ np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03)
+
+ # check sparsity constraints
+ np.testing.assert_array_less(np.sum(Gsc > 0, axis=0),
+ np.ones(n) * max_nz + 1)
+
+
+def test_sparsity_constrained_gradient():
+ max_nz = 5
+ regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz)
+ rng = np.random.RandomState(0)
+ X = rng.randn(10,)
+ b = 0.5
+
+ def delta_omega_func(X):
+ return regularizer.delta_Omega(X)[0]
+
+ def delta_omega_grad(X):
+ return regularizer.delta_Omega(X)[1]
+
+ dual_grad_err = check_grad(delta_omega_func, delta_omega_grad, X)
+ np.testing.assert_allclose(dual_grad_err, 0.0, atol=1e-07)
+
+ def max_omega_func(X, b):
+ return regularizer.max_Omega(X, b)[0]
+
+ def max_omega_grad(X, b):
+ return regularizer.max_Omega(X, b)[1]
+
+ semi_dual_grad_err = check_grad(max_omega_func, max_omega_grad, X, b)
+ np.testing.assert_allclose(semi_dual_grad_err, 0.0, atol=1e-07)
diff --git a/test/test_utils.py b/test/test_utils.py
index 31b12ef..658214d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -41,6 +41,59 @@ def test_proj_simplex(nx):
np.testing.assert_allclose(l1, l2, atol=1e-5)
+def test_projection_sparse_simplex():
+
+ def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None):
+ r"""This is an equivalent but less efficient version
+ of ot.utils.projection_sparse_simplex, as it uses two
+ sorts instead of one.
+ """
+
+ if axis == 0:
+ # For each column of X, find top max_nz values and
+ # their corresponding indices. This incurs a sort.
+ max_nz_indices = np.argpartition(
+ X,
+ kth=-max_nz,
+ axis=0)[-max_nz:]
+
+ max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]
+
+ # Project the top max_nz values onto the simplex.
+ # This incurs a second sort.
+ G_nz_values = ot.smooth.projection_simplex(
+ max_nz_values, z=z, axis=0)
+
+ # Put the projection of max_nz_values to their original indices
+ # and set all other values zero.
+ G = np.zeros_like(X)
+ G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
+ return G
+ elif axis == 1:
+ return double_sort_projection_sparse_simplex(
+ X.T, max_nz, z, axis=0).T
+
+ else:
+ X = X.ravel().reshape(-1, 1)
+ return double_sort_projection_sparse_simplex(
+ X, max_nz, z, axis=0).ravel()
+
+ m, n = 5, 10
+ rng = np.random.RandomState(0)
+ X = rng.uniform(size=(m, n))
+ max_nz = 3
+
+ for axis in [0, 1, None]:
+ slow_sparse_proj = double_sort_projection_sparse_simplex(
+ X, max_nz, axis=axis)
+ fast_sparse_proj = ot.utils.projection_sparse_simplex(
+ X, max_nz, axis=axis)
+
+ # check that two versions produce consistent results
+ np.testing.assert_allclose(
+ slow_sparse_proj, fast_sparse_proj)
+
+
def test_parmap():
n = 10