summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py28
1 files changed, 17 insertions, 11 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 658214d..87f4dc4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -270,25 +270,31 @@ def test_clean_zeros():
assert len(b) == n - nz2
-def test_cost_normalization():
+def test_cost_normalization(nx):
C = np.random.rand(10, 10)
+ C1 = nx.from_numpy(C)
# does nothing
- M0 = ot.utils.cost_normalization(C)
- np.testing.assert_allclose(C, M0)
+ M0 = ot.utils.cost_normalization(C1)
+ M1 = nx.to_numpy(M0)
+ np.testing.assert_allclose(C, M1)
- M = ot.utils.cost_normalization(C, 'median')
- np.testing.assert_allclose(np.median(M), 1)
+ M = ot.utils.cost_normalization(C1, 'median')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(np.median(M1), 1)
- M = ot.utils.cost_normalization(C, 'max')
- np.testing.assert_allclose(M.max(), 1)
+ M = ot.utils.cost_normalization(C1, 'max')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), 1)
- M = ot.utils.cost_normalization(C, 'log')
- np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+ M = ot.utils.cost_normalization(C1, 'log')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), np.log(1 + C).max())
- M = ot.utils.cost_normalization(C, 'loglog')
- np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+ M = ot.utils.cost_normalization(C1, 'loglog')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max())
def test_check_params():