diff options
Diffstat (limited to 'test/test_smooth.py')
-rw-r--r-- | test/test_smooth.py | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/test/test_smooth.py b/test/test_smooth.py index 31e0b2e..dbdd405 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -7,6 +7,7 @@ import numpy as np import ot import pytest +from scipy.optimize import check_grad def test_smooth_ot_dual(): @@ -23,6 +24,7 @@ def test_smooth_ot_dual(): with pytest.raises(NotImplementedError): Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none') + # squared l2 regularisation Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) # check constraints @@ -43,6 +45,24 @@ def test_smooth_ot_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) + # sparsity-constrained regularisation + max_nz = 2 + Gsc, log = ot.smooth.smooth_ot_dual( + u, u, M, 1, + max_nz=max_nz, + log=True, + reg_type='sparsity_constrained', + stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) + np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) + + # check sparsity constraints + np.testing.assert_array_less( + np.sum(Gsc > 0, axis=0), + np.ones(n) * max_nz + 1) + def test_smooth_ot_semi_dual(): @@ -58,6 +78,7 @@ def test_smooth_ot_semi_dual(): with pytest.raises(NotImplementedError): Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none') + # squared l2 regularisation Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) # check constraints @@ -77,3 +98,43 @@ def test_smooth_ot_semi_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) + + # sparsity-constrained regularisation + max_nz = 2 + Gsc = ot.smooth.smooth_ot_semi_dual( + u, u, M, 1, reg_type='sparsity_constrained', + max_nz=max_nz, stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) + np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) + + # check sparsity constraints + np.testing.assert_array_less(np.sum(Gsc > 0, axis=0), + np.ones(n) * max_nz + 1) + + +def test_sparsity_constrained_gradient(): + max_nz = 5 + regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz) + rng = np.random.RandomState(0) + X = rng.randn(10,) + b = 0.5 + + def delta_omega_func(X): + return regularizer.delta_Omega(X)[0] + + def delta_omega_grad(X): + return regularizer.delta_Omega(X)[1] + + dual_grad_err = check_grad(delta_omega_func, delta_omega_grad, X) + np.testing.assert_allclose(dual_grad_err, 0.0, atol=1e-07) + + def max_omega_func(X, b): + return regularizer.max_Omega(X, b)[0] + + def max_omega_grad(X, b): + return regularizer.max_Omega(X, b)[1] + + semi_dual_grad_err = check_grad(max_omega_func, max_omega_grad, X, b) + np.testing.assert_allclose(semi_dual_grad_err, 0.0, atol=1e-07) |