From 0280a3441b09c781035cda3b74213ec92026ff9e Mon Sep 17 00:00:00 2001
From: Kilian
Date: Fri, 15 Nov 2019 16:10:37 +0100
Subject: fix bug numItermax emd in cg

---
 test/test_optim.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

(limited to 'test/test_optim.py')

diff --git a/test/test_optim.py b/test/test_optim.py
index ae31e1f..aade36e 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -37,6 +37,39 @@ def test_conditional_gradient():
     np.testing.assert_allclose(b, G.sum(0))
 
 
+def test_conditional_gradient2():
+    n = 4000  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    mu_t = np.array([4, 4])
+    cov_t = np.array([[1, -.8], [-.8, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+    xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+    a, b = np.ones((n,)) / n, np.ones((n,)) / n
+
+    # loss matrix
+    M = ot.dist(xs, xt)
+    M /= M.max()
+
+    def f(G):
+        return 0.5 * np.sum(G**2)
+
+    def df(G):
+        return G
+
+    reg = 1e-1
+
+    G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+                         verbose=True, log=True)
+
+    np.testing.assert_allclose(a, G.sum(1))
+    np.testing.assert_allclose(b, G.sum(0))
+
+
 def test_generalized_conditional_gradient():
     n_bins = 100  # nb bins
 
--
cgit v1.2.3
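
For context, a minimal hypothetical sketch (not part of this patch, and not POT's verbatim implementation) of the behaviour the new test exercises: the numItermaxEmd keyword of ot.optim.cg is assumed to be forwarded to the exact EMD solver called at each conditional-gradient iteration, so that a 4000-sample problem is not cut off by the solver's default iteration cap. The helper name cg_linearized_step is illustrative only.

import ot


def cg_linearized_step(a, b, M, G, reg, df, numItermaxEmd=100000):
    # Linearize the regularized objective around the current plan G:
    # the gradient of <G, M> + reg * f(G) with respect to G is M + reg * df(G).
    Mi = M + reg * df(G)
    # Solve the linearized subproblem with the exact EMD solver; forwarding
    # numItermax=numItermaxEmd here is the behaviour the test above checks.
    return ot.emd(a, b, Mi, numItermax=numItermaxEmd)

Calling ot.optim.cg with numItermaxEmd=200000, as the test does, is assumed to raise this inner cap above its default so that the marginal constraints G.sum(1) == a and G.sum(0) == b can still be satisfied on the large instance.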