Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py205
-rw-r--r--test/test_da.py65
-rw-r--r--test/test_gpu.py10
-rw-r--r--test/test_gromov.py119
-rw-r--r--test/test_optim.py39
-rw-r--r--test/test_ot.py112
-rw-r--r--test/test_unbalanced.py221
7 files changed, 747 insertions, 24 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 14edaf5..f54ba9f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -1,11 +1,13 @@
"""Tests for module bregman on OT with bregman projections """
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
import numpy as np
import ot
+import pytest
def test_sinkhorn():
@@ -70,19 +72,40 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
-def test_bary():
+def test_sinkhorn_variants_log():
+ # test sinkhorn variants with log=True
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ Ges, loges = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Ges, atol=1e-05)
+ np.testing.assert_allclose(G0, G_green, atol=1e-5)
+ print(G0, G_green)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_barycenter(method):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -100,16 +123,42 @@ def test_bary():
weights = np.array([1 - alpha, alpha])
# wasserstein
- reg = 1e-3
- bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ reg = 1e-2
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
np.testing.assert_allclose(1, np.sum(bary_wass))
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
-def test_wasserstein_bary_2d():
+def test_barycenter_stabilization():
+ n_bins = 100 # nb bins
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-2
+ bar_stable = ot.bregman.barycenter(A, M, reg, weights,
+ method="sinkhorn_stabilized",
+ stopThr=1e-8)
+ bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
+ stopThr=1e-8)
+ np.testing.assert_allclose(bar, bar_stable)
+
+
+def test_wasserstein_bary_2d():
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -133,7 +182,6 @@ def test_wasserstein_bary_2d():
def test_unmix():
-
n_bins = 50 # nb bins
# Gaussian distributions
@@ -155,10 +203,151 @@ def test_unmix():
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
+ um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
ot.bregman.unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)
+
+
+def test_empirical_sinkhorn():
+ # test empirical sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='minkowski')
+
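+ # empirical_sinkhorn builds the cost matrix from the samples internally
+ # (squared Euclidean by default), so it should match sinkhorn run on M above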
+ G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+ G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
+ sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+ loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidean
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidean
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric minkowski
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric minkowski
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence():
+ # Test sinkhorn divergence
+ n = 10
+ a = ot.unif(n)
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_s = ot.dist(X_s, X_s)
+ M_t = ot.dist(X_t, X_t)
+
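+ # sinkhorn divergence: S(a, b) = W(a, b) - (W(a, a) + W(b, b)) / 2,
+ # computed below both directly and from the three sinkhorn2 terms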
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+ sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
+
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+ sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
+ sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
+ sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
+ sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
+
+ # check constraints
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf convergence empirical sinkhorn
+ np.testing.assert_allclose(
+ emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf convergence empirical sinkhorn
+
+
+def test_stabilized_vs_sinkhorn_multidim():
+ # test if stable version matches sinkhorn
+ # for multidimensional inputs
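+ # the stabilized solver absorbs the scaling vectors into the kernel in
+ # log-domain to avoid numerical overflow, so the two plans should coincide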
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ log=True)
+ G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling']
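+ # these methods only support 1-d histograms; sinkhorn runs with them
+ # below, while sinkhorn2 rejects them with a ValueError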
+ NOT_VALID_TOKENS = ['foo']
+ # test that the implemented methods run and invalid tokens raise
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ b = ot.utils.unif(n)
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+
+ for method in IMPLEMENTED_METHODS:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ for method in ONLY_1D_methods:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ with pytest.raises(ValueError):
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+
+
+def test_screenkhorn():
+ # test screenkhorn
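+ # screenkhorn solves a screened (reduced-support) approximation of the
+ # sinkhorn problem, hence the loose 1e-02 tolerance on the marginals below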
+ rng = np.random.RandomState(0)
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ x = rng.randn(n, 2)
+ M = ot.dist(x, x)
+ # sinkhorn
+ G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ # screenkhorn
+ G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ # check marginals
+ np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
+ np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
diff --git a/test/test_da.py b/test/test_da.py
index f7f3a9d..2a5e50e 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -245,6 +245,71 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
+def test_unbalanced_sinkhorn_transport_class():
+ """test_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.UnbalancedSinkhornTransport()
+
+ # test its computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
def test_emd_transport_class():
"""test_sinkhorn_transport
"""
diff --git a/test/test_gpu.py b/test/test_gpu.py
index 6b7fdd4..8e62a74 100644
--- a/test/test_gpu.py
+++ b/test/test_gpu.py
@@ -16,6 +16,16 @@ except ImportError:
@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_old_doctests():
+ a = [.5, .5]
+ b = [.5, .5]
+ M = [[0., 1.], [1., 0.]]
+ G = ot.sinkhorn(a, b, M, 1)
+ np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]]))
+
+
+@pytest.mark.skipif(nogpu, reason="No GPU available")
def test_gpu_dist():
rng = np.random.RandomState(0)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 305ae84..43da9fc 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -2,6 +2,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
@@ -15,7 +16,7 @@ def test_gromov():
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
xt = xs[::-1].copy()
@@ -36,12 +37,21 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
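+ # xt is xs in reverse order, so the optimal coupling should be the
+ # anti-diagonal permutation, i.e. the flipped identity scaled by 1/n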
+ Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04)
+
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+
G = log['T']
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+
# check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
@@ -55,7 +65,7 @@ def test_entropic_gromov():
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
xt = xs[::-1].copy()
@@ -92,12 +102,11 @@ def test_entropic_gromov():
def test_gromov_barycenter():
-
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -120,12 +129,11 @@ def test_gromov_barycenter():
def test_gromov_entropic_barycenter():
-
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -145,3 +153,98 @@ def test_gromov_entropic_barycenter():
'kl_loss', 2e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_fgw():
+
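+ # fused GW trades off the feature cost M against the structure costs
+ # C1, C2 through the interpolation parameter alpha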
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+ Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04) # cf convergence fgw
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+
+def test_fgw_barycenter():
+ np.random.seed(42)
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+
+ ys = np.random.randn(Xs.shape[0], 2)
+ yt = np.random.randn(Xt.shape[0], 2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ xalea = np.random.randn(n_samples, 2)
+ init_C = ot.dist(xalea, xalea)
+
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ init_X = np.random.randn(n_samples, ys.shape[1])
+
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_optim.py b/test/test_optim.py
index dfefe59..aade36e 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -37,6 +37,39 @@ def test_conditional_gradient():
np.testing.assert_allclose(b, G.sum(0))
+def test_conditional_gradient2():
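+ # ot.optim.cg runs Frank-Wolfe (conditional gradient) on
+ # min_G <G, M> + reg * f(G), solving an exact EMD at each iteration
+ # (hence the numItermaxEmd budget passed below)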
+ n = 4000 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([4, 4])
+ cov_t = np.array([[1, -.8], [-.8, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+ xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+ a, b = np.ones((n,)) / n, np.ones((n,)) / n
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+ def df(G):
+ return G
+
+ reg = 1e-1
+
+ G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+ verbose=True, log=True)
+
+ np.testing.assert_allclose(a, G.sum(1))
+ np.testing.assert_allclose(b, G.sum(0))
+
+
def test_generalized_conditional_gradient():
n_bins = 100 # nb bins
@@ -65,3 +98,9 @@ def test_generalized_conditional_gradient():
np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_solve_1d_linesearch_quad_funct():
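+ # solve_1d_linesearch_quad minimizes t -> a*t**2 + b*t + c on [0, 1]:
+ # -b / (2a) clipped to [0, 1] when a > 0, otherwise the best endpoint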
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
diff --git a/test/test_ot.py b/test/test_ot.py
index 7652394..47df946 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -7,20 +7,27 @@
import warnings
import numpy as np
+from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
import pytest
-def test_doctest():
- import doctest
+def test_emd_dimension_mismatch():
+ # test emd and emd2 for dimension mismatch
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples + 1)
- # test lp solver
- doctest.testmod(ot.lp, verbose=True)
+ M = ot.dist(x, x)
- # test bregman solver
- doctest.testmod(ot.bregman, verbose=True)
+ np.testing.assert_raises(AssertionError, ot.emd, a, a, M)
+
+ np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
def test_emd_emd2():
@@ -37,7 +44,7 @@ def test_emd_emd2():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -46,6 +53,64 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
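+ # in 1d the optimal plan can be obtained by sorting the samples, which is
+ # what emd_1d does; it should agree with the generic LP solver on M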
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_wass_1d():
+ # test wasserstein_1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+
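+ # with p=2., wasserstein_1d returns the W_2 distance itself, i.e. the
+ # square root of the squared-Euclidean transport cost computed by emd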
+ wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+
+ # check loss is similar
+ np.testing.assert_allclose(np.sqrt(wass), wass1d)
+
+
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
@@ -60,7 +125,7 @@ def test_emd_empty():
# check G is identity
np.testing.assert_allclose(G, np.eye(n) / n)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
@@ -69,6 +134,28 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)
+def test_emd_sparse():
+
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x2 = rng.randn(n, 2)
+
+ M = ot.dist(x, x2)
+
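+ # with dense=False the plan is returned as a scipy sparse matrix, hence
+ # the .todense() and .multiply() calls in the checks below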
+ G = ot.emd([], [], M, dense=True)
+
+ Gs = ot.emd([], [], M, dense=False)
+
+ ws = ot.emd2([], [], M, dense=False)
+
+ # check G is the same
+ np.testing.assert_allclose(G, Gs.todense())
+ # check value
+ np.testing.assert_allclose(Gs.multiply(M).sum(), ws, rtol=1e-6)
+
+
def test_emd2_multi():
n = 500 # nb bins
@@ -100,7 +187,12 @@ def test_emd2_multi():
emdn = ot.emd2(a, b, M)
ot.toc('multi proc : {} s')
+ ot.tic()
+ emdn2 = ot.emd2(a, b, M, dense=False)
+ ot.toc('multi proc : {} s')
+
np.testing.assert_allclose(emd1, emdn)
+ np.testing.assert_allclose(emd1, emdn2, rtol=1e-6)
# emd loss multipro proc with log
ot.tic()
@@ -246,6 +338,10 @@ def test_dual_variables():
np.testing.assert_almost_equal(cost1, log['cost'])
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
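+ # dual feasibility for the OT linear program: u_i + v_j <= M_ij for all i, j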
+ constraint_violation = log['u'][:, None] + log['v'][None, :] - M
+
+ assert constraint_violation.max() < 1e-8
+
def check_duality_gap(a, b, M, G, u, v, cost):
cost_dual = np.vdot(a, u) + np.vdot(b, v)
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
new file mode 100644
index 0000000..ca1efba
--- /dev/null
+++ b/test/test_unbalanced.py
@@ -0,0 +1,221 @@
+"""Tests for module Unbalanced OT with entropy regularization"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+from ot.unbalanced import barycenter_unbalanced
+
+from scipy.special import logsumexp
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_convergence(method):
+ # test generalized sinkhorn for unbalanced OT
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
+ log=True)
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ # check fixed point equations
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
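+ # at convergence u = (a / Kv)^fi and v = (b / K^T u)^fi with
+ # fi = reg_m / (reg_m + epsilon), i.e. log u = fi * (log a - log Kv)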
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
+
+ np.testing.assert_allclose(
+ u_final, log["logu"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["logv"], atol=1e-05)
+
+ # check if sinkhorn_unbalanced2 returns the correct loss
+ np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_multiple_inputs(method):
+ # test generalized sinkhorn for unbalanced OT
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = rng.rand(n, 2)
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
+ log=True)
+ # check fixed point equations
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
+
+ np.testing.assert_allclose(
+ u_final, log["logu"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["logv"], atol=1e-05)
+
+ assert len(loss) == b.shape[1]
+
+
+def test_stabilized_vs_sinkhorn():
+ # test if stable version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ reg_m = 1.
+ G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ reg_m=reg_m,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_barycenter(method):
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 2])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method, log=True)
+ # check fixed point equations
+ fi = reg_m / (reg_m + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
+
+ np.testing.assert_allclose(
+ u_final, log["logu"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["logv"], atol=1e-05)
+
+
+def test_barycenter_stabilized_vs_sinkhorn():
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 4])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 0.5
+ reg_m = 10
+
+ qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
+ reg_m=reg_m, log=True,
+ tau=100,
+ method="sinkhorn_stabilized",
+ )
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method="sinkhorn",
+ log=True)
+
+ np.testing.assert_allclose(
+ q, qstable, atol=1e-05)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
+ NOT_VALID_TOKENS = ['foo']
+ # test that implemented methods run; pending ones warn and invalid ones raise
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+ for method in IMPLEMENTED_METHODS:
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
+ with pytest.warns(UserWarning, match='not implemented'):
+ for method in set(TO_BE_IMPLEMENTED_METHODS):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)