summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2020-04-24 12:39:38 +0200
committerRémi Flamary <remi.flamary@gmail.com>2020-04-24 12:39:38 +0200
commit53b063ed6b6aa15d6cb103a9304bbd169678b2e9 (patch)
tree33f75e733e1f93c07a5d37f72f085c9722bf19c7 /test
parent46523dc0956fd17e709f958ebd351e748fca0a23 (diff)
better coverage options verbose and log
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py9
-rw-r--r--test/test_optim.py2
-rwxr-xr-xtest/test_partial.py26
-rw-r--r--test/test_stochastic.py8
-rw-r--r--test/test_unbalanced.py9
5 files changed, 42 insertions, 12 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index ec4388d..6aa4e08 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -57,6 +57,9 @@ def test_sinkhorn_empty():
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+ # test empty weights greenkhorn
+ ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
def test_sinkhorn_variants():
# test sinkhorn
@@ -124,7 +127,7 @@ def test_barycenter(method):
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
+ bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -152,9 +155,9 @@ def test_barycenter_stabilization():
reg = 1e-2
bar_stable = ot.bregman.barycenter(A, M, reg, weights,
method="sinkhorn_stabilized",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
np.testing.assert_allclose(bar, bar_stable)
diff --git a/test/test_optim.py b/test/test_optim.py
index aade36e..87b0268 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -38,7 +38,7 @@ def test_conditional_gradient():
def test_conditional_gradient2():
- n = 4000 # nb samples
+ n = 1000 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
diff --git a/test/test_partial.py b/test/test_partial.py
index 8b1ca89..5960e4e 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -9,6 +9,30 @@ import numpy as np
import scipy as sp
import ot
+def test_partial_wasserstein_lagrange():
+
+ n_samples = 20 # nb samples (gaussian)
+ n_noise = 20 # nb of samples (noise)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+ M = ot.dist(xs, xt)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ m = 0.5
+
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+
+
+
def test_partial_wasserstein():
@@ -32,7 +56,7 @@ def test_partial_wasserstein():
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True)
+ log=True, verbose=True)
# check constratints
np.testing.assert_equal(
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index f0f3fc8..8ddf485 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -70,8 +70,8 @@ def test_stochastic_asgd():
M = ot.dist(x, x)
- G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
- numItermax=numItermax)
+ G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
+ numItermax=numItermax, log=True)
# check constratints
np.testing.assert_allclose(
@@ -145,8 +145,8 @@ def test_stochastic_dual_sgd():
M = ot.dist(x, x)
- G = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
- numItermax=numItermax)
+ G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
+ numItermax=numItermax, log=True)
# check constratints
np.testing.assert_allclose(
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index ca1efba..d5bae42 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -31,9 +31,11 @@ def test_unbalanced_convergence(method):
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
- log=True)
+ log=True,
+ verbose=True)
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method)
+ method=method,
+ verbose=True)
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
@@ -73,7 +75,8 @@ def test_unbalanced_multiple_inputs(method):
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
- log=True)
+ log=True,
+ verbose=True)
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)