summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
committerGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
commit35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree6bc637624004713808d3097b95acdccbb9608e52 /test/test_unbalanced.py
parentc4753bd3f74139af8380127b66b484bc09b50661 (diff)
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py207
1 files changed, 146 insertions, 61 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index e8349d1..02b3fc3 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -1,6 +1,7 @@
"""Tests for module Unbalanced OT with entropy regularization"""
# Author: Hicham Janati <hicham.janati@inria.fr>
+# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#
# License: MIT License
@@ -9,11 +10,9 @@ import ot
import pytest
from ot.unbalanced import barycenter_unbalanced
-from scipy.special import logsumexp
-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_convergence(method):
+def test_unbalanced_convergence(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -28,36 +27,51 @@ def test_unbalanced_convergence(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
log=True,
verbose=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method,
- verbose=True)
+ loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method=method, verbose=True
+ ))
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)
- logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
- logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)
+ logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
- np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5)
+
+ # check in case no histogram is provided
+ M_np = nx.to_numpy(M)
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G = ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ G_np = ot.unbalanced.sinkhorn_unbalanced(
+ a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ np.testing.assert_allclose(G_np, nx.to_numpy(G))
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_multiple_inputs(method):
+def test_unbalanced_multiple_inputs(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -72,6 +86,8 @@ def test_unbalanced_multiple_inputs(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
@@ -80,23 +96,24 @@ def test_unbalanced_multiple_inputs(method):
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
assert len(loss) == b.shape[1]
-def test_stabilized_vs_sinkhorn():
+def test_stabilized_vs_sinkhorn(nx):
# test if stable version matches sinkhorn
n = 100
@@ -112,19 +129,27 @@ def test_stabilized_vs_sinkhorn():
M /= np.median(M)
epsilon = 0.1
reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
- method="sinkhorn_stabilized",
- reg_m=reg_m,
- log=True,
- verbose=True)
- G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method="sinkhorn", log=True)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ G, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True
+ )
+ G2, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G = nx.to_numpy(G)
+ G2 = nx.to_numpy(G2)
np.testing.assert_allclose(G, G2, atol=1e-5)
+ np.testing.assert_allclose(G2, G2_np, atol=1e-5)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_barycenter(method):
+def test_unbalanced_barycenter(nx, method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -138,25 +163,29 @@ def test_unbalanced_barycenter(method):
epsilon = 1.
reg_m = 1.
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True, verbose=True)
+ A, M = nx.from_numpy(A, M)
+
+ q, log = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
+ )
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
- logA = np.log(A + 1e-16)
- logq = np.log(q + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logA = nx.log(A + 1e-16)
+ logq = nx.log(q + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logq - logKtu)
u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
-def test_barycenter_stabilized_vs_sinkhorn():
+def test_barycenter_stabilized_vs_sinkhorn(nx):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -170,21 +199,24 @@ def test_barycenter_stabilized_vs_sinkhorn():
epsilon = 0.5
reg_m = 10
- qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
- reg_m=reg_m, log=True,
- tau=100,
- method="sinkhorn_stabilized",
- verbose=True
- )
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method="sinkhorn",
- log=True)
+ Ab, Mb = nx.from_numpy(A, M)
- np.testing.assert_allclose(
- q, qstable, atol=1e-05)
+ qstable, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100,
+ method="sinkhorn_stabilized", verbose=True
+ )
+ q, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q_np, _ = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q, qstable = nx.to_numpy(q, qstable)
+ np.testing.assert_allclose(q, qstable, atol=1e-05)
+ np.testing.assert_allclose(q, q_np, atol=1e-05)
-def test_wrong_method():
+def test_wrong_method(nx):
n = 10
rng = np.random.RandomState(42)
@@ -199,19 +231,20 @@ def test_wrong_method():
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- reg_m=reg_m,
- method='badmethod',
- log=True,
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod',
+ log=True, verbose=True
+ )
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method='badmethod',
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method='badmethod', verbose=True
+ )
-def test_implemented_methods():
+def test_implemented_methods(nx):
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
@@ -228,6 +261,9 @@ def test_implemented_methods():
M = ot.dist(x, x)
epsilon = 1.
reg_m = 1.
+
+ a, b, M, A = nx.from_numpy(a, b, M, A)
+
for method in IMPLEMENTED_METHODS:
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
@@ -251,3 +287,52 @@ def test_implemented_methods():
method=method)
barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method=method)
+
+
+def test_mm_convergence(nx):
+ n = 100
+ rng = np.random.RandomState(42)
+ x = rng.randn(n, 2)
+ rng = np.random.RandomState(75)
+ y = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+ b = ot.utils.unif(n)
+
+ M = ot.dist(x, y)
+ M = M / M.max()
+ reg_m = 100
+ a, b, M = nx.from_numpy(a, b, M)
+
+ G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
+ verbose=True, log=True)
+ loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
+ a, b, M, reg_m, div='kl', verbose=True))
+ G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
+ verbose=False, log=True)
+
+ # check if the marginals come close to the true ones when large reg
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
+ np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
+
+ # check if mm_unbalanced2 returns the correct loss
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
+ atol=1e-5)
+
+ # check in case no histogram is provided
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
+ G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
+ np.testing.assert_allclose(G_kl_null, G_kl)
+ np.testing.assert_allclose(G_l2_null, G_l2)
+
+ # test when G0 is given
+ G0 = ot.emd(a, b, M)
+ reg_m = 10000
+ G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
+ G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
+ np.testing.assert_allclose(G0, G_kl, atol=1e-05)
+ np.testing.assert_allclose(G0, G_l2, atol=1e-05)