diff options
author | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-05 15:30:50 +0900 |
---|---|---|
committer | Antoine Rolet <antoine.rolet@gmail.com> | 2017-09-05 15:30:50 +0900 |
commit | 13dfb3ddbbd8926b4751b82dd41c5570253b1f07 (patch) | |
tree | b28098e98640c64483a599103e2fdb5df46d2c79 /test/test_bregman.py | |
parent | 185eb3e2ef34b5ce6b8f90a28a5bcc78432b7fd3 (diff) | |
parent | 16697047eff9326a0ecb483317c13a854a3d3a71 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py new file mode 100644 index 0000000..4a800fd --- /dev/null +++ b/test/test_bregman.py @@ -0,0 +1,137 @@ +"""Tests for module bregman on OT with bregman projections """ + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import ot + + +def test_sinkhorn(): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + + # check constratints + np.testing.assert_allclose( + u, G.sum(1), atol=1e-05) # cf convergence sinkhorn + np.testing.assert_allclose( + u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + + +def test_sinkhorn_empty(): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) + # check constratints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, + method='sinkhorn_stabilized', verbose=True, log=True) + # check constratints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + G, log = ot.sinkhorn( + [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling', + verbose=True, log=True) + # check constratints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + + +def test_sinkhorn_variants(): + # test sinkhorn + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) + Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10) + Ges = ot.sinkhorn( + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) + Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10) + + # check values + np.testing.assert_allclose(G0, Gs, atol=1e-05) + np.testing.assert_allclose(G0, Ges, atol=1e-05) + np.testing.assert_allclose(G0, Gerr) + + +def test_bary(): + + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.get_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + # wasserstein + reg = 1e-3 + bary_wass = ot.bregman.barycenter(A, M, reg, weights) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + + ot.bregman.barycenter(A, M, reg, log=True, verbose=True) + + +def test_unmix(): + + n_bins = 50 # nb bins + + # Gaussian distributions + a1 = ot.datasets.get_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std + a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10) + + a = ot.datasets.get_1D_gauss(n_bins, m=30, s=10) + + # creating matrix A containing all distributions + D = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + M0 = ot.utils.dist0(2) + M0 /= M0.max() + h0 = ot.unif(2) + + # wasserstein + reg = 1e-3 + um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,) + + np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) + np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) + + ot.bregman.unmix(a, D, M, M0, h0, reg, + 1, alpha=0.01, log=True, verbose=True) |