summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorFrancisco Muñoz <femunoz@dim.uchile.cl>2023-05-10 07:41:15 -0400
committerGitHub <noreply@github.com>2023-05-10 13:41:15 +0200
commit8cc8dd2e8e13022b03bcd013becc08e7e18c404a (patch)
tree8305bdebda3479c1717c69e6e7133e07b4b444bd /test
parent03341c6953c06608ba17d6bf7cd35666bc069988 (diff)
[FIX] Refactor the function `utils.cost_normalization` to work with multiple backends (#472)
* [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends * [TEST] Modify the 'cost_normalization' test to multiple backends * [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends * [TEST] Update backend tests for method 'median' * [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend * [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend. * [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy * [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy * [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy * Add changes to RELEASES.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_backend.py6
-rw-r--r--test/test_utils.py28
2 files changed, 23 insertions, 11 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index fedc62f..799ac54 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -222,6 +222,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.mean(M)
with pytest.raises(NotImplementedError):
+ nx.median(M)
+ with pytest.raises(NotImplementedError):
nx.std(M)
with pytest.raises(NotImplementedError):
nx.linspace(0, 1, 50)
@@ -519,6 +521,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('mean')
+ A = nx.median(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('median')
+
A = nx.std(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('std')
diff --git a/test/test_utils.py b/test/test_utils.py
index 658214d..87f4dc4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -270,25 +270,31 @@ def test_clean_zeros():
assert len(b) == n - nz2
-def test_cost_normalization():
+def test_cost_normalization(nx):
C = np.random.rand(10, 10)
+ C1 = nx.from_numpy(C)
# does nothing
- M0 = ot.utils.cost_normalization(C)
- np.testing.assert_allclose(C, M0)
+ M0 = ot.utils.cost_normalization(C1)
+ M1 = nx.to_numpy(M0)
+ np.testing.assert_allclose(C, M1)
- M = ot.utils.cost_normalization(C, 'median')
- np.testing.assert_allclose(np.median(M), 1)
+ M = ot.utils.cost_normalization(C1, 'median')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(np.median(M1), 1)
- M = ot.utils.cost_normalization(C, 'max')
- np.testing.assert_allclose(M.max(), 1)
+ M = ot.utils.cost_normalization(C1, 'max')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), 1)
- M = ot.utils.cost_normalization(C, 'log')
- np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+ M = ot.utils.cost_normalization(C1, 'log')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), np.log(1 + C).max())
- M = ot.utils.cost_normalization(C, 'loglog')
- np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+ M = ot.utils.cost_normalization(C1, 'loglog')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max())
def test_check_params():