diff options
Diffstat (limited to 'test/test_unbalanced.py')
-rw-r--r-- | test/test_unbalanced.py | 157 |
1 files changed, 96 insertions, 61 deletions
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index e8349d1..db59504 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -9,11 +9,9 @@ import ot import pytest from ot.unbalanced import barycenter_unbalanced -from scipy.special import logsumexp - @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(method): +def test_unbalanced_convergence(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -28,36 +26,51 @@ def test_unbalanced_convergence(method): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method=method, - verbose=True) + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method=method, verbose=True + )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16) + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss - np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) + + # check in case no histogram is provided + M_np = nx.to_numpy(M) + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + G_np = ot.unbalanced.sinkhorn_unbalanced( + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + np.testing.assert_allclose(G_np, nx.to_numpy(G)) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_multiple_inputs(method): +def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -72,6 +85,8 @@ def test_unbalanced_multiple_inputs(method): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, @@ -80,23 +95,24 @@ def test_unbalanced_multiple_inputs(method): # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): +def test_stabilized_vs_sinkhorn(nx): # test if stable version matches sinkhorn n = 100 @@ -112,19 +128,27 @@ def test_stabilized_vs_sinkhorn(): M /= np.median(M) epsilon = 0.1 reg_m = 1. - G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, - method="sinkhorn_stabilized", - reg_m=reg_m, - log=True, - verbose=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method="sinkhorn", log=True) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + G, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True + ) + G2, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True + ) + G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method="sinkhorn", log=True + ) + G = nx.to_numpy(G) + G2 = nx.to_numpy(G2) np.testing.assert_allclose(G, G2, atol=1e-5) + np.testing.assert_allclose(G2, G2_np, atol=1e-5) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_barycenter(method): +def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -138,25 +162,29 @@ def test_unbalanced_barycenter(method): epsilon = 1. reg_m = 1. - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True, verbose=True) + A, M = nx.from_numpy(A, M) + + q, log = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True + ) # check fixed point equations fi = reg_m / (reg_m + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logA = nx.log(A + 1e-16) + logq = nx.log(q + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logq - logKtu) u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) -def test_barycenter_stabilized_vs_sinkhorn(): +def test_barycenter_stabilized_vs_sinkhorn(nx): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -170,21 +198,24 @@ def test_barycenter_stabilized_vs_sinkhorn(): epsilon = 0.5 reg_m = 10 - qstable, log = barycenter_unbalanced(A, M, reg=epsilon, - reg_m=reg_m, log=True, - tau=100, - method="sinkhorn_stabilized", - verbose=True - ) - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method="sinkhorn", - log=True) + Ab, Mb = nx.from_numpy(A, M) - np.testing.assert_allclose( - q, qstable, atol=1e-05) + qstable, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100, + method="sinkhorn_stabilized", verbose=True + ) + q, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q_np, _ = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q, qstable = nx.to_numpy(q, qstable) + np.testing.assert_allclose(q, qstable, atol=1e-05) + np.testing.assert_allclose(q, q_np, atol=1e-05) -def test_wrong_method(): +def test_wrong_method(nx): n = 10 rng = np.random.RandomState(42) @@ -199,19 +230,20 @@ def test_wrong_method(): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method='badmethod', - log=True, - verbose=True) + ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod', + log=True, verbose=True + ) with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method='badmethod', - verbose=True) + ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method='badmethod', verbose=True + ) -def test_implemented_methods(): +def test_implemented_methods(nx): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] @@ -228,6 +260,9 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. reg_m = 1. + + a, b, M, A = nx.from_numpy(a, b, M, A) + for method in IMPLEMENTED_METHODS: ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) |