From 42a62c123776e04ee805aefb9afd6d98abdcf192 Mon Sep 17 00:00:00 2001
From: Tianlin Liu
Date: Tue, 25 Apr 2023 12:14:29 +0200
Subject: [FEAT] add the sparsity-constrained optimal transport functionality
 and example (#459)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add sparsity-constrained ot functionality and example

* correct typos; add projection_sparse_simplex

* add gradcheck; merge ot.sparse into ot.smooth.

* reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type

* address pep8 error

* add backends for

* update releases

---------

Co-authored-by: Rémi Flamary
---
 test/test_smooth.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 test/test_utils.py  | 53 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 114 insertions(+)

(limited to 'test')

diff --git a/test/test_smooth.py b/test/test_smooth.py
index 31e0b2e..dbdd405 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -7,6 +7,7 @@
 import numpy as np
 import ot
 import pytest
+from scipy.optimize import check_grad


 def test_smooth_ot_dual():
@@ -23,6 +24,7 @@ def test_smooth_ot_dual():
     with pytest.raises(NotImplementedError):
         Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none')

+    # squared l2 regularisation
     Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2',
                                         log=True, stopThr=1e-10)
     # check constraints
@@ -43,6 +45,24 @@ def test_smooth_ot_dual():
     G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
     np.testing.assert_allclose(G, G2, atol=1e-05)

+    # sparsity-constrained regularisation
+    max_nz = 2
+    Gsc, log = ot.smooth.smooth_ot_dual(
+        u, u, M, 1,
+        max_nz=max_nz,
+        log=True,
+        reg_type='sparsity_constrained',
+        stopThr=1e-10)
+
+    # check marginal constraints
+    np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03)
+    np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03)
+
+    # check sparsity constraints
+    np.testing.assert_array_less(
+        np.sum(Gsc > 0, axis=0),
+        np.ones(n) * max_nz + 1)
+

 def test_smooth_ot_semi_dual():

@@ -58,6 +78,7 @@ def test_smooth_ot_semi_dual():
     with pytest.raises(NotImplementedError):
         Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none')

+    # squared l2 regularisation
     Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2',
                                              log=True, stopThr=1e-10)
     # check constraints
@@ -77,3 +98,43 @@ def test_smooth_ot_semi_dual():

     G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
     np.testing.assert_allclose(G, G2, atol=1e-05)
+
+    # sparsity-constrained regularisation
+    max_nz = 2
+    Gsc = ot.smooth.smooth_ot_semi_dual(
+        u, u, M, 1, reg_type='sparsity_constrained',
+        max_nz=max_nz, stopThr=1e-10)
+
+    # check marginal constraints
+    np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03)
+    np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03)
+
+    # check sparsity constraints
+    np.testing.assert_array_less(np.sum(Gsc > 0, axis=0),
+                                 np.ones(n) * max_nz + 1)
+
+
+def test_sparsity_constrained_gradient():
+    max_nz = 5
+    regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz)
+    rng = np.random.RandomState(0)
+    X = rng.randn(10,)
+    b = 0.5
+
+    def delta_omega_func(X):
+        return regularizer.delta_Omega(X)[0]
+
+    def delta_omega_grad(X):
+        return regularizer.delta_Omega(X)[1]
+
+    dual_grad_err = check_grad(delta_omega_func, delta_omega_grad, X)
+    np.testing.assert_allclose(dual_grad_err, 0.0, atol=1e-07)
+
+    def max_omega_func(X, b):
+        return regularizer.max_Omega(X, b)[0]
+
+    def max_omega_grad(X, b):
+        return regularizer.max_Omega(X, b)[1]
+
+    semi_dual_grad_err = check_grad(max_omega_func, max_omega_grad, X, b)
+    np.testing.assert_allclose(semi_dual_grad_err, 0.0, atol=1e-07)
diff --git a/test/test_utils.py b/test/test_utils.py
index 31b12ef..658214d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -41,6 +41,59 @@ def test_proj_simplex(nx):
     np.testing.assert_allclose(l1, l2, atol=1e-5)


+def test_projection_sparse_simplex():
+
+    def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None):
+        r"""This is an equivalent but less efficient version
+        of ot.utils.projection_sparse_simplex, as it uses two
+        sorts instead of one.
+        """
+
+        if axis == 0:
+            # For each column of X, find top max_nz values and
+            # their corresponding indices. This incurs a sort.
+            max_nz_indices = np.argpartition(
+                X,
+                kth=-max_nz,
+                axis=0)[-max_nz:]
+
+            max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]
+
+            # Project the top max_nz values onto the simplex.
+            # This incurs a second sort.
+            G_nz_values = ot.smooth.projection_simplex(
+                max_nz_values, z=z, axis=0)
+
+            # Put the projection of max_nz_values to their original indices
+            # and set all other values zero.
+            G = np.zeros_like(X)
+            G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
+            return G
+        elif axis == 1:
+            return double_sort_projection_sparse_simplex(
+                X.T, max_nz, z, axis=0).T
+
+        else:
+            X = X.ravel().reshape(-1, 1)
+            return double_sort_projection_sparse_simplex(
+                X, max_nz, z, axis=0).ravel()
+
+    m, n = 5, 10
+    rng = np.random.RandomState(0)
+    X = rng.uniform(size=(m, n))
+    max_nz = 3
+
+    for axis in [0, 1, None]:
+        slow_sparse_proj = double_sort_projection_sparse_simplex(
+            X, max_nz, axis=axis)
+        fast_sparse_proj = ot.utils.projection_sparse_simplex(
+            X, max_nz, axis=axis)
+
+        # check that two versions produce consistent results
+        np.testing.assert_allclose(
+            slow_sparse_proj, fast_sparse_proj)
+
+
 def test_parmap():

     n = 10
-- 
cgit v1.2.3
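
Note for readers of this patch: below is a minimal usage sketch of the feature these tests exercise. It relies only on the calls that appear in test_smooth.py and test_utils.py above (smooth_ot_dual with reg_type='sparsity_constrained' and max_nz, and ot.utils.projection_sparse_simplex); the data-generation lines (ot.utils.unif, ot.dist) mirror the usual POT test setup and are assumptions, not part of this diff.

# Minimal sketch, assuming the signatures used in the tests above.
import numpy as np
import ot

n = 100
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
u = ot.utils.unif(n)   # uniform source/target marginals (assumed setup)
M = ot.dist(x, x)      # pairwise cost matrix (assumed setup)

# Sparsity-constrained OT: each column of the transport plan keeps at most
# max_nz non-zero entries.
max_nz = 2
Gsc = ot.smooth.smooth_ot_dual(u, u, M, 1,
                               reg_type='sparsity_constrained',
                               max_nz=max_nz, stopThr=1e-10)
assert np.all(np.sum(Gsc > 0, axis=0) <= max_nz)

# The underlying projection, tested in test_utils.py: project each column of a
# matrix onto the simplex while keeping at most max_nz non-zeros per column.
P = ot.utils.projection_sparse_simplex(rng.uniform(size=(5, 10)), 3, axis=0)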