diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_da.py | 46 | ||||
-rw-r--r-- | test/test_gromov.py | 79 | ||||
-rw-r--r-- | test/test_utils.py | 77 |
3 files changed, 200 insertions, 2 deletions
diff --git a/test/test_da.py b/test/test_da.py index 593dc53..3022721 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -326,8 +326,8 @@ def test_mapping_transport_class(): """test_mapping_transport """ - ns = 150 - nt = 200 + ns = 60 + nt = 120 Xs, ys = get_data_classif('3gauss', ns) Xt, yt = get_data_classif('3gauss2', nt) @@ -444,6 +444,48 @@ def test_mapping_transport_class(): assert len(otda.log_.keys()) != 0 +def test_linear_mapping(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + A, b = ot.da.OT_mapping_linear(Xs, Xt) + + Xst = Xs.dot(A) + b + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_linear_mapping_class(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + otmap = ot.da.LinearTransport() + + otmap.fit(Xs=Xs, Xt=Xt) + assert hasattr(otmap, "A_") + assert hasattr(otmap, "B_") + assert hasattr(otmap, "A1_") + assert hasattr(otmap, "B1_") + + Xst = otmap.transform(Xs=Xs) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + def test_otda(): n_samples = 150 # nb samples diff --git a/test/test_gromov.py b/test/test_gromov.py index 625e62a..bb23469 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -36,6 +36,18 @@ def test_gromov(): np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,70 @@ def test_entropic_gromov(): p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_gromov_entropic_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
diff --git a/test/test_utils.py b/test/test_utils.py index 1bd37cd..b524ef6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np +import sys def test_parmap(): @@ -123,3 +124,79 @@ def test_clean_zeros(): assert len(a) == n - nz assert len(b) == n - nz2 + + +def test_cost_normalization(): + + C = np.random.rand(10, 10) + + # does nothing + M0 = ot.utils.cost_normalization(C) + np.testing.assert_allclose(C, M0) + + M = ot.utils.cost_normalization(C, 'median') + np.testing.assert_allclose(np.median(M), 1) + + M = ot.utils.cost_normalization(C, 'max') + np.testing.assert_allclose(M.max(), 1) + + M = ot.utils.cost_normalization(C, 'log') + np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + + M = ot.utils.cost_normalization(C, 'loglog') + np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + + +def test_check_params(): + + res1 = ot.utils.check_params(first='OK', second=20) + assert res1 is True + + res0 = ot.utils.check_params(first='OK', second=None) + assert res0 is False + + +def test_deprecated_func(): + + @ot.utils.deprecated('deprecated text for fun') + def fun(): + pass + + def fun2(): + pass + + @ot.utils.deprecated('deprecated text for class') + class Class(): + pass + + if sys.version_info < (3, 5): + print('Not tested') + else: + assert ot.utils._is_deprecated(fun) is True + + assert ot.utils._is_deprecated(fun2) is False + + +def test_BaseEstimator(): + + class Class(ot.utils.BaseEstimator): + + def __init__(self, first='spam', second='eggs'): + + self.first = first + self.second = second + + cl = Class() + + names = cl._get_param_names() + assert 'first' in names + assert 'second' in names + + params = cl.get_params() + assert 'first' in params + assert 'second' in params + + params['first'] = 'spam again' + cl.set_params(**params) + + assert cl.first == 'spam again' |