summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-07-19 17:04:14 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-08-28 15:36:37 +0200
commitcfdbbd21642c6082164b84db78c2ead07499a113 (patch)
tree11ac0f20c9d5881802ac348be86dff65e5db67d7 /test/test_unbalanced.py
parentabfe183a49caaf74a07e595ac40920dae05a3c22 (diff)
remove square in convergence check
add unbalanced with stabilization add unbalanced tests with stabilization fix doctest examples add xvfb in travis remove explicit call xvfb in travis change alpha to reg_m minor flake8 remove redundant sink definitions + better doc and naming add stabilized unbalanced barycenter + add not converged warnings add test for stable barycenter add generic barycenter func + make method funcs private fix typo + add method test for barycenters fix doc examples + add xml to gitignore fix whitespace in example change logsumexp import - scipy deprecation warning fix doctest improve naming + add stable barycenter in bregman add test for stable bar + test the method arg in bregman
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py163
1 files changed, 119 insertions, 44 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 1395fe1..ca1efba 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -7,9 +7,12 @@
import numpy as np
import ot
import pytest
+from ot.unbalanced import barycenter_unbalanced
+from scipy.special import logsumexp
-@pytest.mark.parametrize("method", ["sinkhorn"])
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -23,29 +26,35 @@ def test_unbalanced_convergence(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, method=method,
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
log=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
- u_final = (a / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
-@pytest.mark.parametrize("method", ["sinkhorn"])
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_multiple_inputs(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -59,28 +68,59 @@ def test_unbalanced_multiple_inputs(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- alpha=alpha,
- stopThr=1e-10, method=method,
+ reg_m=reg_m,
+ method=method,
log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
-
- u_final = (a[:, None] / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
assert len(loss) == b.shape[1]
-def test_unbalanced_barycenter():
+def test_stabilized_vs_sinkhorn():
+ # test if stable version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ reg_m = 1.
+ G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ reg_m=reg_m,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_barycenter(method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -92,27 +132,56 @@ def test_unbalanced_barycenter():
A = A * np.array([1, 2])[None, :]
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10,
- log=True)
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method, log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
- u_final = (A / K.dot(log["v"])) ** fi
+ fi = reg_m / (reg_m + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
+
+
+def test_barycenter_stabilized_vs_sinkhorn():
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 4])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 0.5
+ reg_m = 10
+
+ qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
+ reg_m=reg_m, log=True,
+ tau=100,
+ method="sinkhorn_stabilized",
+ )
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method="sinkhorn",
+ log=True)
+
+ np.testing.assert_allclose(
+ q, qstable, atol=1e-05)
def test_implemented_methods():
- IMPLEMENTED_METHODS = ['sinkhorn']
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
- 'sinkhorn_epsilon_scaling']
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
# test generalized sinkhorn for unbalanced OT barycenter
n = 3
@@ -123,24 +192,30 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n) * 1.5
-
+ A = rng.rand(n, 2)
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
+ reg_m = 1.
for method in IMPLEMENTED_METHODS:
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.warns(UserWarning, match='not implemented'):
for method in set(TO_BE_IMPLEMENTED_METHODS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.raises(ValueError):
for method in set(NOT_VALID_TOKENS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)