diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-07-19 17:04:14 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-08-28 15:36:37 +0200 |
commit | cfdbbd21642c6082164b84db78c2ead07499a113 (patch) | |
tree | 11ac0f20c9d5881802ac348be86dff65e5db67d7 /test/test_bregman.py | |
parent | abfe183a49caaf74a07e595ac40920dae05a3c22 (diff) |
remove square in convergence check
add unbalanced with stabilization
add unbalanced tests with stabilization
fix doctest examples
add xvfb in travis
remove explicit call xvfb in travis
change alpha to reg_m
minor flake8
remove redundant sink definitions + better doc and naming
add stabilized unbalanced barycenter + add not converged warnings
add test for stable barycenter
add generic barycenter func + make method funcs private
fix typo + add method test for barycenters
fix doc examples + add xml to gitignore
fix whitespace in example
change logsumexp import - scipy deprecation warning
fix doctest
improve naming + add stable barycenter in bregman
add test for stable bar + test the method arg in bregman
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 72 |
1 files changed, 65 insertions, 7 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 83ebba8..f70df10 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,6 +7,7 @@ import numpy as np import ot +import pytest def test_sinkhorn(): @@ -71,13 +72,11 @@ def test_sinkhorn_variants(): Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10) Ges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) - Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10) G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) - np.testing.assert_allclose(G0, Gerr) np.testing.assert_allclose(G0, G_green, atol=1e-5) print(G0, G_green) @@ -96,18 +95,17 @@ def test_sinkhorn_variants_log(): Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) - Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True) G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Ges, atol=1e-05) - np.testing.assert_allclose(G0, Gerr) np.testing.assert_allclose(G0, G_green, atol=1e-5) print(G0, G_green) -def test_bary(): +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_barycenter(method): n_bins = 100 # nb bins @@ -126,14 +124,42 @@ def test_bary(): weights = np.array([1 - alpha, alpha]) # wasserstein - reg = 1e-3 - bary_wass = ot.bregman.barycenter(A, M, reg, weights) + reg = 1e-2 + bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method) np.testing.assert_allclose(1, np.sum(bary_wass)) ot.bregman.barycenter(A, M, reg, log=True, verbose=True) +def test_barycenter_stabilization(): + + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + # wasserstein + reg = 1e-2 + bar_stable = ot.bregman.barycenter(A, M, reg, weights, + method="sinkhorn_stabilized", + stopThr=1e-8) + bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", + stopThr=1e-8) + np.testing.assert_allclose(bar, bar_stable) + + def test_wasserstein_bary_2d(): size = 100 # size of a square image @@ -279,3 +305,35 @@ def test_stabilized_vs_sinkhorn_multidim(): method="sinkhorn", log=True) np.testing.assert_allclose(G, G2) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) + A = rng.rand(n, 2) + M = ot.dist(x, x) + epsilon = 1. + + for method in IMPLEMENTED_METHODS: + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + ot.bregman.barycenter(A, M, reg=epsilon, method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) + ot.bregman.barycenter(A, M, reg=epsilon, method=method) + for method in ONLY_1D_methods: + ot.bregman.sinkhorn(a, b, M, epsilon, method=method) + with pytest.raises(ValueError): + ot.bregman.sinkhorn2(a, b, M, epsilon, method=method) |