summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py72
1 files changed, 65 insertions, 7 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 83ebba8..f70df10 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -7,6 +7,7 @@
import numpy as np
import ot
+import pytest
def test_sinkhorn():
@@ -71,13 +72,11 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
@@ -96,18 +95,17 @@ def test_sinkhorn_variants_log():
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
- Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
-def test_bary():
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_barycenter(method):
n_bins = 100 # nb bins
@@ -126,14 +124,42 @@ def test_bary():
weights = np.array([1 - alpha, alpha])
# wasserstein
- reg = 1e-3
- bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ reg = 1e-2
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
np.testing.assert_allclose(1, np.sum(bary_wass))
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+def test_barycenter_stabilization():
+
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-2
+ bar_stable = ot.bregman.barycenter(A, M, reg, weights,
+ method="sinkhorn_stabilized",
+ stopThr=1e-8)
+ bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
+ stopThr=1e-8)
+ np.testing.assert_allclose(bar, bar_stable)
+
+
def test_wasserstein_bary_2d():
size = 100 # size of a square image
@@ -279,3 +305,35 @@ def test_stabilized_vs_sinkhorn_multidim():
method="sinkhorn", log=True)
np.testing.assert_allclose(G, G2)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling']
+ NOT_VALID_TOKENS = ['foo']
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n)
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+
+ for method in IMPLEMENTED_METHODS:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ for method in ONLY_1D_methods:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ with pytest.raises(ValueError):
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)