diff options
author | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100 |
commit | bdfb24ff37ea777d6e266b145047cd4e281ebac3 (patch) | |
tree | 00cbac5f3dc25a4ee76164828abd72c1cbab37cc /test/test_unbalanced.py | |
parent | abc441b00f0fe2fa4ef0efc4e1aa67b27cca9a13 (diff) | |
parent | 5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff) |
Merge tag '0.6.0' into debian/sid
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r-- | test/test_unbalanced.py | 221 |
1 files changed, 221 insertions, 0 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py new file mode 100644 index 0000000..ca1efba --- /dev/null +++ b/test/test_unbalanced.py @@ -0,0 +1,221 @@ +"""Tests for module Unbalanced OT with entropy regularization""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +import numpy as np +import ot +import pytest +from ot.unbalanced import barycenter_unbalanced + +from scipy.special import logsumexp + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_convergence(method): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, + method=method, + log=True) + loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method=method) + # check fixed point equations + # in log-domain + fi = reg_m / (reg_m + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16) + logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) + + np.testing.assert_allclose( + u_final, log["logu"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["logv"], atol=1e-05) + + # check if sinkhorn_unbalanced2 returns the correct loss + np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_multiple_inputs(method): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, + reg_m=reg_m, + method=method, + log=True) + # check fixed point equations + # in log-domain + fi = reg_m / (reg_m + epsilon) + logb = np.log(b + 1e-16) + loga = np.log(a + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logb - logKtu) + u_final = fi * (loga - logKv) + + np.testing.assert_allclose( + u_final, log["logu"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["logv"], atol=1e-05) + + assert len(loss) == b.shape[1] + + +def test_stabilized_vs_sinkhorn(): + # test if stable version matches sinkhorn + n = 100 + + # Gaussian distributions + a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std + b1 = ot.datasets.make_1D_gauss(n, m=60, s=8) + b2 = ot.datasets.make_1D_gauss(n, m=30, s=4) + + # creating matrix A containing all distributions + b = np.vstack((b1, b2)).T + + M = ot.utils.dist0(n) + M /= np.median(M) + epsilon = 0.1 + reg_m = 1. + G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, + method="sinkhorn_stabilized", + reg_m=reg_m, + log=True) + G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method="sinkhorn", log=True) + + np.testing.assert_allclose(G, G2, atol=1e-5) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_unbalanced_barycenter(method): + # test generalized sinkhorn for unbalanced OT barycenter + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + A = rng.rand(n, 2) + + # make dists unbalanced + A = A * np.array([1, 2])[None, :] + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + + q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method, log=True) + # check fixed point equations + fi = reg_m / (reg_m + epsilon) + logA = np.log(A + 1e-16) + logq = np.log(q + 1e-16)[:, None] + logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, + axis=0) + logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + v_final = fi * (logq - logKtu) + u_final = fi * (logA - logKv) + + np.testing.assert_allclose( + u_final, log["logu"], atol=1e-05) + np.testing.assert_allclose( + v_final, log["logv"], atol=1e-05) + + +def test_barycenter_stabilized_vs_sinkhorn(): + # test generalized sinkhorn for unbalanced OT barycenter + n = 100 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + A = rng.rand(n, 2) + + # make dists unbalanced + A = A * np.array([1, 4])[None, :] + M = ot.dist(x, x) + epsilon = 0.5 + reg_m = 10 + + qstable, log = barycenter_unbalanced(A, M, reg=epsilon, + reg_m=reg_m, log=True, + tau=100, + method="sinkhorn_stabilized", + ) + q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method="sinkhorn", + log=True) + + np.testing.assert_allclose( + q, qstable, atol=1e-05) + + +def test_implemented_methods(): + IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] + TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] + NOT_VALID_TOKENS = ['foo'] + # test generalized sinkhorn for unbalanced OT barycenter + n = 3 + rng = np.random.RandomState(42) + + x = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = ot.utils.unif(n) * 1.5 + A = rng.rand(n, 2) + M = ot.dist(x, x) + epsilon = 1. + reg_m = 1. + for method in IMPLEMENTED_METHODS: + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method=method) + barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method) + with pytest.warns(UserWarning, match='not implemented'): + for method in set(TO_BE_IMPLEMENTED_METHODS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method=method) + barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method) + with pytest.raises(ValueError): + for method in set(NOT_VALID_TOKENS): + ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, + method=method) + ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, + method=method) + barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, + method=method) |