summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py46
-rw-r--r--test/test_gromov.py79
-rw-r--r--test/test_utils.py77
3 files changed, 200 insertions, 2 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 593dc53..3022721 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -326,8 +326,8 @@ def test_mapping_transport_class():
"""test_mapping_transport
"""
- ns = 150
- nt = 200
+ ns = 60
+ nt = 120
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
@@ -444,6 +444,48 @@ def test_mapping_transport_class():
assert len(otda.log_.keys()) != 0
+def test_linear_mapping():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ A, b = ot.da.OT_mapping_linear(Xs, Xt)
+
+ Xst = Xs.dot(A) + b
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_linear_mapping_class():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ otmap = ot.da.LinearTransport()
+
+ otmap.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otmap, "A_")
+ assert hasattr(otmap, "B_")
+ assert hasattr(otmap, "A1_")
+ assert hasattr(otmap, "B1_")
+
+ Xst = otmap.transform(Xs=Xs)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
def test_otda():
n_samples = 150 # nb samples
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 625e62a..bb23469 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -36,6 +36,18 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,70 @@ def test_entropic_gromov():
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_gromov_entropic_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
diff --git a/test/test_utils.py b/test/test_utils.py
index 1bd37cd..b524ef6 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,6 +7,7 @@
import ot
import numpy as np
+import sys
def test_parmap():
@@ -123,3 +124,79 @@ def test_clean_zeros():
assert len(a) == n - nz
assert len(b) == n - nz2
+
+
+def test_cost_normalization():
+
+ C = np.random.rand(10, 10)
+
+ # does nothing
+ M0 = ot.utils.cost_normalization(C)
+ np.testing.assert_allclose(C, M0)
+
+ M = ot.utils.cost_normalization(C, 'median')
+ np.testing.assert_allclose(np.median(M), 1)
+
+ M = ot.utils.cost_normalization(C, 'max')
+ np.testing.assert_allclose(M.max(), 1)
+
+ M = ot.utils.cost_normalization(C, 'log')
+ np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+
+ M = ot.utils.cost_normalization(C, 'loglog')
+ np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+
+
+def test_check_params():
+
+ res1 = ot.utils.check_params(first='OK', second=20)
+ assert res1 is True
+
+ res0 = ot.utils.check_params(first='OK', second=None)
+ assert res0 is False
+
+
+def test_deprecated_func():
+
+ @ot.utils.deprecated('deprecated text for fun')
+ def fun():
+ pass
+
+ def fun2():
+ pass
+
+ @ot.utils.deprecated('deprecated text for class')
+ class Class():
+ pass
+
+ if sys.version_info < (3, 5):
+ print('Not tested')
+ else:
+ assert ot.utils._is_deprecated(fun) is True
+
+ assert ot.utils._is_deprecated(fun2) is False
+
+
+def test_BaseEstimator():
+
+ class Class(ot.utils.BaseEstimator):
+
+ def __init__(self, first='spam', second='eggs'):
+
+ self.first = first
+ self.second = second
+
+ cl = Class()
+
+ names = cl._get_param_names()
+ assert 'first' in names
+ assert 'second' in names
+
+ params = cl.get_params()
+ assert 'first' in params
+ assert 'second' in params
+
+ params['first'] = 'spam again'
+ cl.set_params(**params)
+
+ assert cl.first == 'spam again'