diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-11-18 11:13:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-18 11:13:03 +0100 |
commit | bbd8f2046ec42751eba8e5356366aded74a2930d (patch) | |
tree | 6e8b8ceb9bfc59257360ef2d2d54417a8041a30e /test | |
parent | 3635fc46d6fc55e6fa30b33ad07fe092dfd23241 (diff) | |
parent | 0280a3441b09c781035cda3b74213ec92026ff9e (diff) |
Merge pull request #108 from kilianFatras/master
[MRG] Fix log and nbiter bug in gromov_wasserstein and gromov_wasserstein2
Diffstat (limited to 'test')
-rw-r--r-- | test/test_gromov.py | 4 | ||||
-rw-r--r-- | test/test_optim.py | 33 |
2 files changed, 37 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py index 70fa83f..43da9fc 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -44,10 +44,14 @@ def test_gromov(): gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+
G = log['T']
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+
# check constratints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
diff --git a/test/test_optim.py b/test/test_optim.py index ae31e1f..aade36e 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -37,6 +37,39 @@ def test_conditional_gradient(): np.testing.assert_allclose(b, G.sum(0)) +def test_conditional_gradient2(): + n = 4000 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + mu_t = np.array([4, 4]) + cov_t = np.array([[1, -.8], [-.8, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) + xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + + a, b = np.ones((n,)) / n, np.ones((n,)) / n + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + def f(G): + return 0.5 * np.sum(G**2) + + def df(G): + return G + + reg = 1e-1 + + G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000, + verbose=True, log=True) + + np.testing.assert_allclose(a, G.sum(1)) + np.testing.assert_allclose(b, G.sum(0)) + + def test_generalized_conditional_gradient(): n_bins = 100 # nb bins |