diff options
author | Hicham Janati <hicham.janati@inria.fr> | 2019-07-23 21:51:53 +0200 |
---|---|---|
committer | Hicham Janati <hicham.janati@inria.fr> | 2019-07-23 21:51:53 +0200 |
commit | a507556b1901e16351c211e69b38d8d74ac2bc3d (patch) | |
tree | 96f505830c031022580d32d08d1d2081e9e45204 /test | |
parent | a725f1dc0ac63ac919461ab8f2a23b111a410c00 (diff) |
rebase unbalanced
Diffstat (limited to 'test')
-rw-r--r-- | test/test_unbalanced.py | 116 |
1 files changed, 39 insertions, 77 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index fc7aa5e..1395fe1 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -8,10 +8,8 @@ import numpy as np import ot import pytest -from scipy.misc import logsumexp - -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_convergence(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -25,34 +23,29 @@ def test_unbalanced_convergence(method): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha, stopThr=1e-10, method=method, log=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) # check fixed point equations - # in log-domain - fi = mu / (mu + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) - - v_final = fi * (logb - logKtu) - u_final = fi * (loga - logKv) + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + u_final = (a / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_unbalanced_multiple_inputs(method): # test generalized sinkhorn for unbalanced OT n = 100 @@ -66,55 +59,27 @@ def test_unbalanced_multiple_inputs(method): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, mu=mu, + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + alpha=alpha, stopThr=1e-10, method=method, log=True) # check fixed point equations - # in log-domain - fi = mu / (mu + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) - v_final = fi * (logb - logKtu) - u_final = fi * (loga - logKv) + fi = alpha / (alpha + epsilon) + v_final = (b / K.T.dot(log["u"])) ** fi + + u_final = (a[:, None] / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): - # test if stable version matches sinkhorn - n = 100 - - # Gaussian distributions - a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std - b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) - b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) - - # creating matrix A containing all distributions - b = np.vstack((b1, b2)).T - - M = ot.utils.dist0(n) - M /= np.median(M) - epsilon = 0.1 - mu = 1. - G, log = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=epsilon, - mu=mu, - log=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, - method="sinkhorn", log=True) - - np.testing.assert_allclose(G, G2) - - def test_unbalanced_barycenter(): # test generalized sinkhorn for unbalanced OT barycenter n = 100 @@ -127,30 +92,27 @@ def test_unbalanced_barycenter(): A = A * np.array([1, 2])[None, :] M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. + K = np.exp(- M / epsilon) - q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, mu=mu, + q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True) # check fixed point equations - fi = mu / (mu + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) - v_final = fi * (logq - logKtu) - u_final = fi * (logA - logKv) + fi = alpha / (alpha + epsilon) + v_final = (q[:, None] / K.T.dot(log["u"])) ** fi + u_final = (A / K.dot(log["v"])) ** fi np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + u_final, log["u"], atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + v_final, log["v"], atol=1e-05) def test_implemented_methods(): - IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] - TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] + IMPLEMENTED_METHODS = ['sinkhorn'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized', + 'sinkhorn_epsilon_scaling'] NOT_VALID_TOKENS = ['foo'] # test generalized sinkhorn for unbalanced OT barycenter n = 3 @@ -164,21 +126,21 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. - mu = 1. + alpha = 1. for method in IMPLEMENTED_METHODS: - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) with pytest.warns(UserWarning, match='not implemented'): for method in set(TO_BE_IMPLEMENTED_METHODS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) with pytest.raises(ValueError): for method in set(NOT_VALID_TOKENS): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, method=method) - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, mu, + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha, method=method) |