diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-21 10:08:17 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-21 10:08:17 +0100 |
commit | 55aaf7874c651235d44c34b89337df7694e55014 (patch) | |
tree | 821b86edbf0e14c11885f466a249dc1263090e80 /test | |
parent | 927395b40dae98bcf027b601b6df48a4318cfef2 (diff) |
add test gromov + debug sklearn Basestimator
Diffstat (limited to 'test')
-rw-r--r-- | test/test_da.py | 4 | ||||
-rw-r--r-- | test/test_gromov.py | 25 |
2 files changed, 27 insertions, 2 deletions
diff --git a/test/test_da.py b/test/test_da.py index a9d6d34..3022721 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -326,8 +326,8 @@ def test_mapping_transport_class(): """test_mapping_transport """ - ns = 150 - nt = 200 + ns = 60 + nt = 120 Xs, ys = get_data_classif('3gauss', ns) Xt, yt = get_data_classif('3gauss2', nt) diff --git a/test/test_gromov.py b/test/test_gromov.py index 625e62a..0384ee1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -36,6 +36,18 @@ def test_gromov(): np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,16 @@ def test_entropic_gromov(): p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
|