From 53b063ed6b6aa15d6cb103a9304bbd169678b2e9 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 24 Apr 2020 12:39:38 +0200 Subject: better coverage options verbose and log --- test/test_bregman.py | 9 ++++++--- test/test_optim.py | 2 +- test/test_partial.py | 26 +++++++++++++++++++++++++- test/test_stochastic.py | 8 ++++---- test/test_unbalanced.py | 9 ++++++--- 5 files changed, 42 insertions(+), 12 deletions(-) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index ec4388d..6aa4e08 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -57,6 +57,9 @@ def test_sinkhorn_empty(): np.testing.assert_allclose(u, G.sum(1), atol=1e-05) np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + # test empty weights greenkhorn + ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True) + def test_sinkhorn_variants(): # test sinkhorn @@ -124,7 +127,7 @@ def test_barycenter(method): # wasserstein reg = 1e-2 - bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method) + bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) np.testing.assert_allclose(1, np.sum(bary_wass)) @@ -152,9 +155,9 @@ def test_barycenter_stabilization(): reg = 1e-2 bar_stable = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn_stabilized", - stopThr=1e-8) + stopThr=1e-8, verbose=True) bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", - stopThr=1e-8) + stopThr=1e-8, verbose=True) np.testing.assert_allclose(bar, bar_stable) diff --git a/test/test_optim.py b/test/test_optim.py index aade36e..87b0268 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -38,7 +38,7 @@ def test_conditional_gradient(): def test_conditional_gradient2(): - n = 4000 # nb samples + n = 1000 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) diff --git a/test/test_partial.py b/test/test_partial.py index 8b1ca89..5960e4e 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -9,6 +9,30 @@ import numpy as np import scipy as sp import ot +def test_partial_wasserstein_lagrange(): + + n_samples = 20 # nb samples (gaussian) + n_noise = 20 # nb of samples (noise) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2)) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2)) + + M = ot.dist(xs, xt) + + p = ot.unif(n_samples + n_noise) + q = ot.unif(n_samples + n_noise) + + m = 0.5 + + w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True) + + + def test_partial_wasserstein(): @@ -32,7 +56,7 @@ def test_partial_wasserstein(): w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True) w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, - log=True) + log=True, verbose=True) # check constratints np.testing.assert_equal( diff --git a/test/test_stochastic.py b/test/test_stochastic.py index f0f3fc8..8ddf485 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -70,8 +70,8 @@ def test_stochastic_asgd(): M = ot.dist(x, x) - G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", - numItermax=numItermax) + G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd", + numItermax=numItermax, log=True) # check constratints np.testing.assert_allclose( @@ -145,8 +145,8 @@ def test_stochastic_dual_sgd(): M = ot.dist(x, x) - G = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, - numItermax=numItermax) + G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size, + numItermax=numItermax, log=True) # check constratints np.testing.assert_allclose( diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index ca1efba..d5bae42 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -31,9 +31,11 @@ def test_unbalanced_convergence(method): G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, - log=True) + log=True, + verbose=True) loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method=method) + method=method, + verbose=True) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) @@ -73,7 +75,8 @@ def test_unbalanced_multiple_inputs(method): loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, - log=True) + log=True, + verbose=True) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) -- cgit v1.2.3