From fc9923dea2706b65ffe15fc86428cd8b53b5feb1 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 21 Mar 2018 09:33:57 +0100 Subject: add tests for ot.uils --- test/test_utils.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) (limited to 'test/test_utils.py') diff --git a/test/test_utils.py b/test/test_utils.py index 1bd37cd..b524ef6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np +import sys def test_parmap(): @@ -123,3 +124,79 @@ def test_clean_zeros(): assert len(a) == n - nz assert len(b) == n - nz2 + + +def test_cost_normalization(): + + C = np.random.rand(10, 10) + + # does nothing + M0 = ot.utils.cost_normalization(C) + np.testing.assert_allclose(C, M0) + + M = ot.utils.cost_normalization(C, 'median') + np.testing.assert_allclose(np.median(M), 1) + + M = ot.utils.cost_normalization(C, 'max') + np.testing.assert_allclose(M.max(), 1) + + M = ot.utils.cost_normalization(C, 'log') + np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + + M = ot.utils.cost_normalization(C, 'loglog') + np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + + +def test_check_params(): + + res1 = ot.utils.check_params(first='OK', second=20) + assert res1 is True + + res0 = ot.utils.check_params(first='OK', second=None) + assert res0 is False + + +def test_deprecated_func(): + + @ot.utils.deprecated('deprecated text for fun') + def fun(): + pass + + def fun2(): + pass + + @ot.utils.deprecated('deprecated text for class') + class Class(): + pass + + if sys.version_info < (3, 5): + print('Not tested') + else: + assert ot.utils._is_deprecated(fun) is True + + assert ot.utils._is_deprecated(fun2) is False + + +def test_BaseEstimator(): + + class Class(ot.utils.BaseEstimator): + + def __init__(self, first='spam', second='eggs'): + + self.first = first + self.second = second + + cl = Class() + + names = cl._get_param_names() + assert 'first' in names + assert 'second' in names + + params = cl.get_params() + assert 'first' in params + assert 'second' in params + + params['first'] = 'spam again' + cl.set_params(**params) + + assert cl.first == 'spam again' -- cgit v1.2.3