diff options
Diffstat (limited to 'test/test_utils.py')
-rw-r--r-- | test/test_utils.py | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 658214d..87f4dc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -270,25 +270,31 @@ def test_clean_zeros(): assert len(b) == n - nz2 -def test_cost_normalization(): +def test_cost_normalization(nx): C = np.random.rand(10, 10) + C1 = nx.from_numpy(C) # does nothing - M0 = ot.utils.cost_normalization(C) - np.testing.assert_allclose(C, M0) + M0 = ot.utils.cost_normalization(C1) + M1 = nx.to_numpy(M0) + np.testing.assert_allclose(C, M1) - M = ot.utils.cost_normalization(C, 'median') - np.testing.assert_allclose(np.median(M), 1) + M = ot.utils.cost_normalization(C1, 'median') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(np.median(M1), 1) - M = ot.utils.cost_normalization(C, 'max') - np.testing.assert_allclose(M.max(), 1) + M = ot.utils.cost_normalization(C1, 'max') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), 1) - M = ot.utils.cost_normalization(C, 'log') - np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + M = ot.utils.cost_normalization(C1, 'log') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + C).max()) - M = ot.utils.cost_normalization(C, 'loglog') - np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + M = ot.utils.cost_normalization(C1, 'loglog') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) def test_check_params(): |