summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-07-22 14:54:01 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-07-22 14:54:01 +0200
commit5c0ed104b2890c609bdadfe0fcb0e836ba7a6ef1 (patch)
tree548409d4a5ecf1d6ffe967fb48d57214463ff212 /test
parent10accb13c2f22c946b65b249d7aae6e4f6af7579 (diff)
add unbalanced tests with stabilization
Diffstat (limited to 'test')
-rw-r--r--test/test_unbalanced.py116
1 files changed, 77 insertions, 39 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 1395fe1..fc7aa5e 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -8,8 +8,10 @@ import numpy as np
import ot
import pytest
+from scipy.misc import logsumexp
-@pytest.mark.parametrize("method", ["sinkhorn"])
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -23,29 +25,34 @@ def test_unbalanced_convergence(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ mu = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu,
stopThr=1e-10, method=method,
log=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
method=method)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
- u_final = (a / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = mu / (mu + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
-@pytest.mark.parametrize("method", ["sinkhorn"])
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_multiple_inputs(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -59,27 +66,55 @@ def test_unbalanced_multiple_inputs(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ mu = 1.
- loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- alpha=alpha,
+ loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu,
stopThr=1e-10, method=method,
log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
-
- u_final = (a[:, None] / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = mu / (mu + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
assert len(loss) == b.shape[1]
+def test_stabilized_vs_sinkhorn():
+ # test if stable version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ mu = 1.
+ G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon,
+ mu=mu,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2)
+
+
def test_unbalanced_barycenter():
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
@@ -92,27 +127,30 @@ def test_unbalanced_barycenter():
A = A * np.array([1, 2])[None, :]
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ mu = 1.
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
+ q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu,
stopThr=1e-10,
log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
- u_final = (A / K.dot(log["v"])) ** fi
+ fi = mu / (mu + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
def test_implemented_methods():
- IMPLEMENTED_METHODS = ['sinkhorn']
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
- 'sinkhorn_epsilon_scaling']
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
# test generalized sinkhorn for unbalanced OT barycenter
n = 3
@@ -126,21 +164,21 @@ def test_implemented_methods():
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
+ mu = 1.
for method in IMPLEMENTED_METHODS:
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
method=method)
with pytest.warns(UserWarning, match='not implemented'):
for method in set(TO_BE_IMPLEMENTED_METHODS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
method=method)
with pytest.raises(ValueError):
for method in set(NOT_VALID_TOKENS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu,
method=method)