diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2018-08-30 08:26:42 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-30 08:26:42 +0200 |
commit | da5d07b4949877148f1582a9f0649c34282afa30 (patch) | |
tree | 82c33ee8b09112b6a67ed614e370156e4144628f /test | |
parent | 5180023fc49d15ad83faccc5674d5966fe9a0385 (diff) | |
parent | 15f4b29a91fda1dbd221e6e0a3443431d3d69257 (diff) |
Merge pull request #62 from kilianFatras/stochastic_OT
Debug and speedup SGD stochastic OT
Diffstat (limited to 'test')
-rw-r--r-- | test/test_stochastic.py | 54 |
1 files changed, 39 insertions, 15 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py index f315c88..0128317 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -97,7 +97,6 @@ def test_sag_asgd_sinkhorn(): x = rng.randn(n, 2) u = ot.utils.unif(n) - zero = np.zeros(n) M = ot.dist(x, x) G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", @@ -108,13 +107,13 @@ def test_sag_asgd_sinkhorn(): # check constratints np.testing.assert_allclose( - zero, (G_sag - G_sinkhorn).sum(1), atol=1e-03) # cf convergence sag + G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03) np.testing.assert_allclose( - zero, (G_sag - G_sinkhorn).sum(0), atol=1e-03) # cf convergence sag + G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03) np.testing.assert_allclose( - zero, (G_asgd - G_sinkhorn).sum(1), atol=1e-03) # cf convergence asgd + G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) np.testing.assert_allclose( - zero, (G_asgd - G_sinkhorn).sum(0), atol=1e-03) # cf convergence asgd + G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) np.testing.assert_allclose( G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag np.testing.assert_allclose( @@ -137,8 +136,8 @@ def test_stochastic_dual_sgd(): # test sgd n = 10 reg = 1 - numItermax = 300000 - batch_size = 8 + numItermax = 15000 + batch_size = 10 rng = np.random.RandomState(0) x = rng.randn(n, 2) @@ -151,9 +150,9 @@ def test_stochastic_dual_sgd(): # check constratints np.testing.assert_allclose( - u, G.sum(1), atol=1e-02) # cf convergence sgd + u, G.sum(1), atol=1e-03) # cf convergence sgd np.testing.assert_allclose( - u, G.sum(0), atol=1e-02) # cf convergence sgd + u, G.sum(0), atol=1e-03) # cf convergence sgd ############################################################################# @@ -168,13 +167,13 @@ def test_dual_sgd_sinkhorn(): # test all dual algorithms n = 10 reg = 1 - nb_iter = 300000 - batch_size = 8 + nb_iter = 150000 + batch_size = 10 rng = np.random.RandomState(0) +# Test uniform x = rng.randn(n, 2) u = ot.utils.unif(n) - zero = np.zeros(n) M = ot.dist(x, x) G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, @@ -184,8 +183,33 @@ def test_dual_sgd_sinkhorn(): # check constratints np.testing.assert_allclose( - zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-02) # cf convergence sgd + G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) + np.testing.assert_allclose( + G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) + np.testing.assert_allclose( + G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd + +# Test gaussian + n = 30 + reg = 1 + batch_size = 30 + + a = ot.datasets.make_1D_gauss(n, 15, 5) # m= mean, s= std + b = ot.datasets.make_1D_gauss(n, 15, 5) + X_source = np.arange(n, dtype=np.float64) + Y_target = np.arange(n, dtype=np.float64) + M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1))) + M /= M.max() + + G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, + numItermax=nb_iter) + + G_sinkhorn = ot.sinkhorn(a, b, M, reg) + + # check constratints + np.testing.assert_allclose( + G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03) np.testing.assert_allclose( - zero, (G_sgd - G_sinkhorn).sum(0), atol=1e-02) # cf convergence sgd + G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) np.testing.assert_allclose( - G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd + G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd |