summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2018-08-30 08:26:42 +0200
committerGitHub <noreply@github.com>2018-08-30 08:26:42 +0200
commitda5d07b4949877148f1582a9f0649c34282afa30 (patch)
tree82c33ee8b09112b6a67ed614e370156e4144628f /test
parent5180023fc49d15ad83faccc5674d5966fe9a0385 (diff)
parent15f4b29a91fda1dbd221e6e0a3443431d3d69257 (diff)
Merge pull request #62 from kilianFatras/stochastic_OT
Debug and speedup SGD stochastic OT
Diffstat (limited to 'test')
-rw-r--r--test/test_stochastic.py54
1 files changed, 39 insertions, 15 deletions
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index f315c88..0128317 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -97,7 +97,6 @@ def test_sag_asgd_sinkhorn():
x = rng.randn(n, 2)
u = ot.utils.unif(n)
- zero = np.zeros(n)
M = ot.dist(x, x)
G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
@@ -108,13 +107,13 @@ def test_sag_asgd_sinkhorn():
# check constratints
np.testing.assert_allclose(
- zero, (G_sag - G_sinkhorn).sum(1), atol=1e-03) # cf convergence sag
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_sag - G_sinkhorn).sum(0), atol=1e-03) # cf convergence sag
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_asgd - G_sinkhorn).sum(1), atol=1e-03) # cf convergence asgd
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_asgd - G_sinkhorn).sum(0), atol=1e-03) # cf convergence asgd
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
@@ -137,8 +136,8 @@ def test_stochastic_dual_sgd():
# test sgd
n = 10
reg = 1
- numItermax = 300000
- batch_size = 8
+ numItermax = 15000
+ batch_size = 10
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -151,9 +150,9 @@ def test_stochastic_dual_sgd():
# check constratints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-02) # cf convergence sgd
+ u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-02) # cf convergence sgd
+ u, G.sum(0), atol=1e-03) # cf convergence sgd
#############################################################################
@@ -168,13 +167,13 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 300000
- batch_size = 8
+ nb_iter = 150000
+ batch_size = 10
rng = np.random.RandomState(0)
+# Test uniform
x = rng.randn(n, 2)
u = ot.utils.unif(n)
- zero = np.zeros(n)
M = ot.dist(x, x)
G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
@@ -184,8 +183,33 @@ def test_dual_sgd_sinkhorn():
# check constratints
np.testing.assert_allclose(
- zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-02) # cf convergence sgd
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+# Test gaussian
+ n = 30
+ reg = 1
+ batch_size = 30
+
+ a = ot.datasets.make_1D_gauss(n, 15, 5) # m= mean, s= std
+ b = ot.datasets.make_1D_gauss(n, 15, 5)
+ X_source = np.arange(n, dtype=np.float64)
+ Y_target = np.arange(n, dtype=np.float64)
+ M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))
+ M /= M.max()
+
+ G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
+ numItermax=nb_iter)
+
+ G_sinkhorn = ot.sinkhorn(a, b, M, reg)
+
+ # check constratints
+ np.testing.assert_allclose(
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
- zero, (G_sgd - G_sinkhorn).sum(0), atol=1e-02) # cf convergence sgd
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
np.testing.assert_allclose(
- G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
+ G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd