summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 09:33:57 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 09:33:57 +0100
commitfc9923dea2706b65ffe15fc86428cd8b53b5feb1 (patch)
treee47fd6e6a9b8e0afacfdc3035917bd1da1fe6afd /test/test_utils.py
parent5efdf008865ea347775708b637d933e048d663ec (diff)
add tests for ot.uils
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py77
1 files changed, 77 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 1bd37cd..b524ef6 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,6 +7,7 @@
import ot
import numpy as np
+import sys
def test_parmap():
@@ -123,3 +124,79 @@ def test_clean_zeros():
assert len(a) == n - nz
assert len(b) == n - nz2
+
+
+def test_cost_normalization():
+
+ C = np.random.rand(10, 10)
+
+ # does nothing
+ M0 = ot.utils.cost_normalization(C)
+ np.testing.assert_allclose(C, M0)
+
+ M = ot.utils.cost_normalization(C, 'median')
+ np.testing.assert_allclose(np.median(M), 1)
+
+ M = ot.utils.cost_normalization(C, 'max')
+ np.testing.assert_allclose(M.max(), 1)
+
+ M = ot.utils.cost_normalization(C, 'log')
+ np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+
+ M = ot.utils.cost_normalization(C, 'loglog')
+ np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+
+
+def test_check_params():
+
+ res1 = ot.utils.check_params(first='OK', second=20)
+ assert res1 is True
+
+ res0 = ot.utils.check_params(first='OK', second=None)
+ assert res0 is False
+
+
+def test_deprecated_func():
+
+ @ot.utils.deprecated('deprecated text for fun')
+ def fun():
+ pass
+
+ def fun2():
+ pass
+
+ @ot.utils.deprecated('deprecated text for class')
+ class Class():
+ pass
+
+ if sys.version_info < (3, 5):
+ print('Not tested')
+ else:
+ assert ot.utils._is_deprecated(fun) is True
+
+ assert ot.utils._is_deprecated(fun2) is False
+
+
+def test_BaseEstimator():
+
+ class Class(ot.utils.BaseEstimator):
+
+ def __init__(self, first='spam', second='eggs'):
+
+ self.first = first
+ self.second = second
+
+ cl = Class()
+
+ names = cl._get_param_names()
+ assert 'first' in names
+ assert 'second' in names
+
+ params = cl.get_params()
+ assert 'first' in params
+ assert 'second' in params
+
+ params['first'] = 'spam again'
+ cl.set_params(**params)
+
+ assert cl.first == 'spam again'