From 8cc8dd2e8e13022b03bcd013becc08e7e18c404a Mon Sep 17 00:00:00 2001 From: Francisco Muñoz Date: Wed, 10 May 2023 07:41:15 -0400 Subject: [FIX] Refactor the function `utils.cost_normalization` to work with multiple backends (#472) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends * [TEST] Modify the 'cost_normalization' test to multiple backends * [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends * [TEST] Update backend tests for method 'median' * [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend * [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend. * [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy * [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy * [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy * Add changes to RELEASES.md --------- Co-authored-by: Rémi Flamary --- test/test_backend.py | 6 ++++++ test/test_utils.py | 28 +++++++++++++++++----------- 2 files changed, 23 insertions(+), 11 deletions(-) (limited to 'test') diff --git a/test/test_backend.py b/test/test_backend.py index fedc62f..799ac54 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -221,6 +221,8 @@ def test_empty_backend(): nx.argmin(M) with pytest.raises(NotImplementedError): nx.mean(M) + with pytest.raises(NotImplementedError): + nx.median(M) with pytest.raises(NotImplementedError): nx.std(M) with pytest.raises(NotImplementedError): @@ -519,6 +521,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('mean') + A = nx.median(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('median') + A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('std') diff --git a/test/test_utils.py b/test/test_utils.py index 658214d..87f4dc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -270,25 +270,31 @@ def test_clean_zeros(): assert len(b) == n - nz2 -def test_cost_normalization(): +def test_cost_normalization(nx): C = np.random.rand(10, 10) + C1 = nx.from_numpy(C) # does nothing - M0 = ot.utils.cost_normalization(C) - np.testing.assert_allclose(C, M0) + M0 = ot.utils.cost_normalization(C1) + M1 = nx.to_numpy(M0) + np.testing.assert_allclose(C, M1) - M = ot.utils.cost_normalization(C, 'median') - np.testing.assert_allclose(np.median(M), 1) + M = ot.utils.cost_normalization(C1, 'median') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(np.median(M1), 1) - M = ot.utils.cost_normalization(C, 'max') - np.testing.assert_allclose(M.max(), 1) + M = ot.utils.cost_normalization(C1, 'max') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), 1) - M = ot.utils.cost_normalization(C, 'log') - np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + M = ot.utils.cost_normalization(C1, 'log') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + C).max()) - M = ot.utils.cost_normalization(C, 'loglog') - np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + M = ot.utils.cost_normalization(C1, 'loglog') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) def test_check_params(): -- cgit v1.2.3