diff options
author | Francisco Muñoz <femunoz@dim.uchile.cl> | 2023-05-10 07:41:15 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-10 13:41:15 +0200 |
commit | 8cc8dd2e8e13022b03bcd013becc08e7e18c404a (patch) | |
tree | 8305bdebda3479c1717c69e6e7133e07b4b444bd /test/test_backend.py | |
parent | 03341c6953c06608ba17d6bf7cd35666bc069988 (diff) |
[FIX] Refactor the function `utils.cost_normalization` to work with multiple backends (#472)
* [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends
* [TEST] Modify the 'cost_normalization' test to multiple backends
* [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends
* [TEST] Update backend tests for method 'median'
* [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend
* [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend.
* [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy
* [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy
* [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy
* Add changes to RELEASES.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_backend.py')
-rw-r--r-- | test/test_backend.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/test/test_backend.py b/test/test_backend.py index fedc62f..799ac54 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -222,6 +222,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.mean(M) with pytest.raises(NotImplementedError): + nx.median(M) + with pytest.raises(NotImplementedError): nx.std(M) with pytest.raises(NotImplementedError): nx.linspace(0, 1, 50) @@ -519,6 +521,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('mean') + A = nx.median(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('median') + A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('std') |