summaryrefslogtreecommitdiff
path: root/test/test_unbalanced.py
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2022-03-24 10:53:47 +0100
committerGitHub <noreply@github.com>2022-03-24 10:53:47 +0100
commit767171593f2a98a26b9a39bf110a45085e3b982e (patch)
tree4eb4bcc657efc53a65c3fb4439bd0e0e106b6745 /test/test_unbalanced.py
parent9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff)
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r--test/test_unbalanced.py157
1 files changed, 96 insertions, 61 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index e8349d1..db59504 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -9,11 +9,9 @@ import ot
import pytest
from ot.unbalanced import barycenter_unbalanced
-from scipy.special import logsumexp
-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_convergence(method):
+def test_unbalanced_convergence(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -28,36 +26,51 @@ def test_unbalanced_convergence(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
log=True,
verbose=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method,
- verbose=True)
+ loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method=method, verbose=True
+ ))
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)
- logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
- logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)
+ logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
- np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5)
+
+ # check in case no histogram is provided
+ M_np = nx.to_numpy(M)
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G = ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ G_np = ot.unbalanced.sinkhorn_unbalanced(
+ a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ np.testing.assert_allclose(G_np, nx.to_numpy(G))
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_multiple_inputs(method):
+def test_unbalanced_multiple_inputs(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -72,6 +85,8 @@ def test_unbalanced_multiple_inputs(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
@@ -80,23 +95,24 @@ def test_unbalanced_multiple_inputs(method):
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
assert len(loss) == b.shape[1]
-def test_stabilized_vs_sinkhorn():
+def test_stabilized_vs_sinkhorn(nx):
# test if stable version matches sinkhorn
n = 100
@@ -112,19 +128,27 @@ def test_stabilized_vs_sinkhorn():
M /= np.median(M)
epsilon = 0.1
reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
- method="sinkhorn_stabilized",
- reg_m=reg_m,
- log=True,
- verbose=True)
- G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method="sinkhorn", log=True)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ G, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True
+ )
+ G2, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G = nx.to_numpy(G)
+ G2 = nx.to_numpy(G2)
np.testing.assert_allclose(G, G2, atol=1e-5)
+ np.testing.assert_allclose(G2, G2_np, atol=1e-5)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_barycenter(method):
+def test_unbalanced_barycenter(nx, method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -138,25 +162,29 @@ def test_unbalanced_barycenter(method):
epsilon = 1.
reg_m = 1.
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True, verbose=True)
+ A, M = nx.from_numpy(A, M)
+
+ q, log = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
+ )
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
- logA = np.log(A + 1e-16)
- logq = np.log(q + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logA = nx.log(A + 1e-16)
+ logq = nx.log(q + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logq - logKtu)
u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
-def test_barycenter_stabilized_vs_sinkhorn():
+def test_barycenter_stabilized_vs_sinkhorn(nx):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -170,21 +198,24 @@ def test_barycenter_stabilized_vs_sinkhorn():
epsilon = 0.5
reg_m = 10
- qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
- reg_m=reg_m, log=True,
- tau=100,
- method="sinkhorn_stabilized",
- verbose=True
- )
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method="sinkhorn",
- log=True)
+ Ab, Mb = nx.from_numpy(A, M)
- np.testing.assert_allclose(
- q, qstable, atol=1e-05)
+ qstable, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100,
+ method="sinkhorn_stabilized", verbose=True
+ )
+ q, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q_np, _ = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q, qstable = nx.to_numpy(q, qstable)
+ np.testing.assert_allclose(q, qstable, atol=1e-05)
+ np.testing.assert_allclose(q, q_np, atol=1e-05)
-def test_wrong_method():
+def test_wrong_method(nx):
n = 10
rng = np.random.RandomState(42)
@@ -199,19 +230,20 @@ def test_wrong_method():
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- reg_m=reg_m,
- method='badmethod',
- log=True,
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod',
+ log=True, verbose=True
+ )
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method='badmethod',
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method='badmethod', verbose=True
+ )
-def test_implemented_methods():
+def test_implemented_methods(nx):
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
@@ -228,6 +260,9 @@ def test_implemented_methods():
M = ot.dist(x, x)
epsilon = 1.
reg_m = 1.
+
+ a, b, M, A = nx.from_numpy(a, b, M, A)
+
for method in IMPLEMENTED_METHODS:
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)