Diffstat (limited to 'test')
-rw-r--r--  test/test_bregman.py      42
-rw-r--r--  test/test_da.py          210
-rw-r--r--  test/test_gromov.py        4
-rw-r--r--  test/test_optim.py        33
-rw-r--r--  test/test_ot.py           67
-rwxr-xr-x  test/test_partial.py     208
-rw-r--r--  test/test_stochastic.py    8
-rw-r--r--  test/test_unbalanced.py    9
-rw-r--r--  test/test_utils.py         4
9 files changed, 557 insertions, 28 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index f70df10..6aa4e08 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -57,6 +57,9 @@ def test_sinkhorn_empty():
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+ # test empty weights greenkhorn
+ ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
def test_sinkhorn_variants():
# test sinkhorn
@@ -106,7 +109,6 @@ def test_sinkhorn_variants_log():
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_barycenter(method):
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -125,7 +127,7 @@ def test_barycenter(method):
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
+ bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -133,7 +135,6 @@ def test_barycenter(method):
def test_barycenter_stabilization():
-
n_bins = 100 # nb bins
# Gaussian distributions
@@ -154,14 +155,13 @@ def test_barycenter_stabilization():
reg = 1e-2
bar_stable = ot.bregman.barycenter(A, M, reg, weights,
method="sinkhorn_stabilized",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8)
+ stopThr=1e-8, verbose=True)
np.testing.assert_allclose(bar, bar_stable)
def test_wasserstein_bary_2d():
-
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -185,7 +185,6 @@ def test_wasserstein_bary_2d():
def test_unmix():
-
n_bins = 50 # nb bins
# Gaussian distributions
@@ -207,7 +206,7 @@ def test_unmix():
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+ um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -256,7 +255,7 @@ def test_empirical_sinkhorn():
def test_empirical_sinkhorn_divergence():
- #Test sinkhorn divergence
+ # Test sinkhorn divergence
n = 10
a = ot.unif(n)
b = ot.unif(n)
@@ -337,3 +336,28 @@ def test_implemented_methods():
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
with pytest.raises(ValueError):
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+
+
+def test_screenkhorn():
+ # test screenkhorn
+ rng = np.random.RandomState(0)
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ x = rng.randn(n, 2)
+ M = ot.dist(x, x)
+ # sinkhorn
+ G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ # screenkhorn
+ G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ # check marginals
+ np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+ np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
+
+
+def test_convolutional_barycenter_non_square():
+ # test for image with height not equal width
+ A = np.ones((2, 2, 3)) / (2 * 3)
+ b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
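A minimal sketch of the screened variant's main knob, assuming the ns_budget/nt_budget parameters documented for ot.bregman.screenkhorn (the new test above only exercises uniform=True), which cap how many source/target points stay active in the screened problem:

import numpy as np
import ot

rng = np.random.RandomState(0)
n = 100
a, b = ot.unif(n), ot.unif(n)
x = rng.randn(n, 2)
M = ot.dist(x, x)

# keep only half of the points on each side in the screened problem
G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03,
                                  ns_budget=n // 2, nt_budget=n // 2)

# marginals are only approximately preserved, hence the loose atol=1e-02 above
print(np.abs(G_screen.sum(0) - b).max())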
diff --git a/test/test_da.py b/test/test_da.py
index 2a5e50e..3b28119 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -5,7 +5,7 @@
# License: MIT License
import numpy as np
-from numpy.testing.utils import assert_allclose, assert_equal
+from numpy.testing import assert_allclose, assert_equal
import ot
from ot.datasets import make_data_classif
@@ -65,6 +65,16 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -129,6 +139,16 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -210,6 +230,16 @@ def test_sinkhorn_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -271,6 +301,16 @@ def test_unbalanced_sinkhorn_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+
Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
@@ -353,6 +393,16 @@ def test_emd_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+
Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
@@ -510,7 +560,6 @@ def test_mapping_transport_class():
def test_linear_mapping():
-
ns = 150
nt = 200
@@ -528,7 +577,6 @@ def test_linear_mapping():
def test_linear_mapping_class():
-
ns = 150
nt = 200
@@ -549,3 +597,159 @@ def test_linear_mapping_class():
Cst = np.cov(Xst.T)
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_jcpot_transport_class():
+ """test_jcpot_transport
+ """
+
+ ns1 = 150
+ ns2 = 150
+ nt = 200
+
+ Xs1, ys1 = make_data_classif('3gauss', ns1)
+ Xs2, ys2 = make_data_classif('3gauss', ns2)
+
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs = [Xs1, Xs2]
+ ys = [ys1, ys2]
+
+ otda = ot.da.JCPOTTransport(reg_e=1, max_iter=10000, tol=1e-9, verbose=True, log=True)
+
+ # test that it's computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "proportions_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ for i, xs in enumerate(Xs):
+ assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0])))
+
+ # test all margin constraints
+ mu_t = unif(nt)
+
+ for i in range(len(Xs)):
+ # test margin constraints w.r.t. uniform target weights for each coupling matrix
+ assert_allclose(
+ np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3)
+
+ # test margin constraints w.r.t. modified source weights for each source domain
+
+ assert_allclose(
+ np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
+ atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
+
+ Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)]
+ [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)]
+
+
+def test_jcpot_barycenter():
+ """test_jcpot_barycenter
+ """
+
+ ns1 = 150
+ ns2 = 150
+ nt = 200
+
+ sigma = 0.1
+ np.random.seed(1985)
+
+ ps1 = .2
+ ps2 = .9
+ pt = .4
+
+ Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1)
+ Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2)
+ Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
+
+ Xs = [Xs1, Xs2]
+ ys = [ys1, ys2]
+
+ prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean',
+ numItermax=10000, stopThr=1e-9, verbose=False, log=False)
+
+ np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+
+
+def test_emd_laplace_class():
+ """test_emd_laplace_transport
+ """
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True)
+
+ # test that it's computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test all margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # check label propagation
+ transp_yt = otda.transform_labels(ys)
+ assert_equal(transp_yt.shape[0], yt.shape[0])
+ assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+
+ # check inverse label propagation
+ transp_ys = otda.inverse_transform_labels(yt)
+ assert_equal(transp_ys.shape[0], ys.shape[0])
+ assert_equal(transp_ys.shape[1], len(np.unique(yt)))
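The label-propagation checks above rely on transform_labels returning, for each target sample, a vector of class proportions propagated through the coupling. A minimal sketch of turning those proportions into hard target labels (the class ordering via np.unique is an assumption here):

import numpy as np
import ot
from ot.datasets import make_data_classif

Xs, ys = make_data_classif('3gauss', 150)
Xt, yt = make_data_classif('3gauss2', 200)

otda = ot.da.EMDTransport()
otda.fit(Xs=Xs, ys=ys, Xt=Xt)

transp_yt = otda.transform_labels(ys)            # shape (nt, n_classes)
classes = np.unique(ys)
yt_pred = classes[np.argmax(transp_yt, axis=1)]  # hard labels for Xt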
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 70fa83f..43da9fc 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -44,10 +44,14 @@ def test_gromov():
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+
G = log['T']
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+
# check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
diff --git a/test/test_optim.py b/test/test_optim.py
index ae31e1f..87b0268 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -37,6 +37,39 @@ def test_conditional_gradient():
np.testing.assert_allclose(b, G.sum(0))
+def test_conditional_gradient2():
+ n = 1000 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([4, 4])
+ cov_t = np.array([[1, -.8], [-.8, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+ xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+ a, b = np.ones((n,)) / n, np.ones((n,)) / n
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+ def df(G):
+ return G
+
+ reg = 1e-1
+
+ G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+ verbose=True, log=True)
+
+ np.testing.assert_allclose(a, G.sum(1))
+ np.testing.assert_allclose(b, G.sum(0))
+
+
def test_generalized_conditional_gradient():
n_bins = 100 # nb bins
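For context on test_conditional_gradient2: as documented, ot.optim.cg solves the generically regularized problem

    min_G   <G, M>_F + reg * f(G)
    s.t.    G 1 = a,   G^T 1 = b,   G >= 0

by conditional gradient (Frank-Wolfe), where each linearized subproblem is an exact EMD; numItermaxEmd caps the iterations of that inner solver. The test instantiates the squared-norm regularization f(G) = 0.5 * ||G||_F^2, whose gradient is df(G) = G.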
diff --git a/test/test_ot.py b/test/test_ot.py
index dacae0a..b7306f6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,11 +7,27 @@
import warnings
import numpy as np
+import pytest
from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
-import pytest
+
+
+def test_emd_dimension_mismatch():
+ # test emd and emd2 for dimension mismatch
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples + 1)
+
+ M = ot.dist(x, x)
+
+ np.testing.assert_raises(AssertionError, ot.emd, a, a, M)
+
+ np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
def test_emd_emd2():
@@ -59,12 +75,12 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(wass, wass1d_emd2)
# check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)
# check constraints
- np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
# check G is similar
np.testing.assert_allclose(G, G_1d)
@@ -76,6 +92,42 @@ def test_emd_1d_emd2_1d():
ot.emd_1d(u, v, [], [])
+def test_emd_1d_emd2_1d_with_weights():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(w_u, G.sum(1))
+ np.testing.assert_allclose(w_v, G.sum(0))
+
+
def test_wass_1d():
# test emd1d gives similar results as emd
n = 20
@@ -168,7 +220,6 @@ def test_emd2_multi():
def test_lp_barycenter():
-
a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]
@@ -185,7 +236,6 @@ def test_lp_barycenter():
def test_free_support_barycenter():
-
measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
measures_weights = [np.array([1.]), np.array([1.])]
@@ -201,7 +251,6 @@ def test_free_support_barycenter():
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
-
a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]
@@ -295,6 +344,10 @@ def test_dual_variables():
np.testing.assert_almost_equal(cost1, log['cost'])
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
+ constraint_violation = log['u'][:, None] + log['v'][None, :] - M
+
+ assert constraint_violation.max() < 1e-8
+
def check_duality_gap(a, b, M, G, u, v, cost):
cost_dual = np.vdot(a, u) + np.vdot(b, v)
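The added assertion checks dual feasibility: the Kantorovich dual maximizes a^T u + b^T v subject to u_i + v_j <= M_ij, so the broadcast u[:, None] + v[None, :] - M must be nonpositive (up to solver tolerance), and at the optimum the dual value matches the primal cost. A minimal standalone sketch, using only the log keys ('u', 'v', 'cost') the test itself reads:

import numpy as np
import ot

rng = np.random.RandomState(0)
n = 50
a, b = ot.unif(n), ot.unif(n)
M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))

G, log = ot.emd(a, b, M, log=True)
u, v = log['u'], log['v']

assert (u[:, None] + v[None, :] - M).max() < 1e-8          # dual feasibility
np.testing.assert_allclose(np.vdot(a, u) + np.vdot(b, v),
                           log['cost'])                     # zero duality gap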
diff --git a/test/test_partial.py b/test/test_partial.py
new file mode 100755
index 0000000..510e081
--- /dev/null
+++ b/test/test_partial.py
@@ -0,0 +1,208 @@
+"""Tests for module partial """
+
+# Author:
+# Laetitia Chapel <laetitia.chapel@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy as sp
+import ot
+import pytest
+
+
+def test_raise_errors():
+
+ n_samples = 20 # nb samples (gaussian)
+ n_noise = 20 # nb of samples (noise)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+ M = ot.dist(xs, xt)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ with pytest.raises(ValueError):
+ ot.partial.partial_wasserstein_lagrange(p + 1, q, M, 1, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.partial_wasserstein(p, q, M, m=2, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.partial_wasserstein(p, q, M, m=-1, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=2, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=-1, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.partial_gromov_wasserstein(M, M, p, q, m=2, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True)
+
+ with pytest.raises(ValueError):
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True)
+
+
+def test_partial_wasserstein_lagrange():
+
+ n_samples = 20 # nb samples (gaussian)
+ n_noise = 20 # nb of samples (noise)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+ M = ot.dist(xs, xt)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+
+
+def test_partial_wasserstein():
+
+ n_samples = 20 # nb samples (gaussian)
+ n_noise = 20 # nb of samples (noise)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+ M = ot.dist(xs, xt)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ m = 0.5
+
+ w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
+ log=True, verbose=True)
+
+ # check constraints
+ np.testing.assert_equal(
+ w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+
+ # check transported mass
+ np.testing.assert_allclose(
+ np.sum(w0), m, atol=1e-04)
+ np.testing.assert_allclose(
+ np.sum(w), m, atol=1e-04)
+
+ w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+ w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
+
+ G = log0['T']
+
+ np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_equal(
+ G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_allclose(
+ np.sum(G), m, atol=1e-04)
+
+
+def test_partial_gromov_wasserstein():
+ n_samples = 20 # nb samples
+ n_noise = 10 # nb of samples (noise)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([0, 0, 0])
+ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ P = sp.linalg.sqrtm(cov_t)
+ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt2 = xs[::-1].copy()
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ C3 = ot.dist(xt2, xt2)
+
+ m = 2 / 3
+ res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C3, p, q, m=m,
+ log=True, verbose=True)
+ np.testing.assert_allclose(res0, 0, atol=1e-1, rtol=1e-1)
+
+ C1 = sp.spatial.distance.cdist(xs, xs)
+ C2 = sp.spatial.distance.cdist(xt, xt)
+
+ m = 1
+ res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
+ log=True)
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
+ np.testing.assert_allclose(G, res0, atol=1e-04)
+
+ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=10)
+ np.testing.assert_allclose(G, res, atol=1e-02)
+
+ w0, log0 = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m,
+ log=True)
+ w0_val = ot.partial.partial_gromov_wasserstein2(C1, C2, p, q, m=m,
+ log=False)
+ G = log0['T']
+ np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
+
+ m = 2 / 3
+ res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m,
+ log=True)
+ res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q,
+ 100, m=m,
+ log=True)
+
+ # check constraints
+ np.testing.assert_equal(
+ res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ res0.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_allclose(
+ np.sum(res0), m, atol=1e-04)
+
+ np.testing.assert_equal(
+ res.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ np.testing.assert_equal(
+ res.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_allclose(
+ np.sum(res), m, atol=1e-04)
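The pattern behind these checks: in partial OT the parameter m is the fraction of total mass actually transported, so the plan's marginals become upper bounds (row sums <= p, column sums <= q) rather than equalities, and the plan's total mass equals m. A minimal sketch using the same partial_wasserstein signature as the tests above:

import numpy as np
import ot

n = 20
p, q = ot.unif(n), ot.unif(n)
x = np.random.randn(n, 2)
M = ot.dist(x, x)

w, log = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
assert np.allclose(w.sum(), 0.5, atol=1e-4)    # total transported mass is m
assert (w.sum(1) <= p + 1e-10).all()           # row marginals bounded by p
assert (w.sum(0) <= q + 1e-10).all()           # column marginals bounded by q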
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index f0f3fc8..155622c 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -70,8 +70,8 @@ def test_stochastic_asgd():
M = ot.dist(x, x)
- G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
- numItermax=numItermax)
+ G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
+ numItermax=numItermax, log=True)
# check constraints
np.testing.assert_allclose(
@@ -145,8 +145,8 @@ def test_stochastic_dual_sgd():
M = ot.dist(x, x)
- G = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
- numItermax=numItermax)
+ G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
+ numItermax=numItermax, log=True)
# check constraints
np.testing.assert_allclose(
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index ca1efba..dfeaad9 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -31,9 +31,11 @@ def test_unbalanced_convergence(method):
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
- log=True)
+ log=True,
+ verbose=True)
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method)
+ method=method,
+ verbose=True)
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
@@ -73,7 +75,8 @@ def test_unbalanced_multiple_inputs(method):
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
- log=True)
+ log=True,
+ verbose=True)
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
diff --git a/test/test_utils.py b/test/test_utils.py
index 640598d..db9cda6 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -36,10 +36,10 @@ def test_tic_toc():
t2 = ot.toq()
# test timing
- np.testing.assert_allclose(0.5, t, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1)
# test toc vs toq
- np.testing.assert_allclose(t, t2, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1)
def test_kernel():