Diffstat (limited to 'test')
-rw-r--r--  test/test_bregman.py      339
-rw-r--r--  test/test_da.py           551
-rw-r--r--  test/test_dr.py            59
-rw-r--r--  test/test_gpu.py          106
-rw-r--r--  test/test_gromov.py       246
-rw-r--r--  test/test_optim.py         73
-rw-r--r--  test/test_ot.py           308
-rw-r--r--  test/test_plot.py          60
-rw-r--r--  test/test_smooth.py        79
-rw-r--r--  test/test_stochastic.py   215
-rw-r--r--  test/test_unbalanced.py   221
-rw-r--r--  test/test_utils.py        202
12 files changed, 2459 insertions(+), 0 deletions(-)
diff --git a/test/test_bregman.py b/test/test_bregman.py
new file mode 100644
index 0000000..f70df10
--- /dev/null
+++ b/test/test_bregman.py
@@ -0,0 +1,339 @@
+"""Tests for module bregman on OT with bregman projections """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+
+
+def test_sinkhorn():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
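+
+# A minimal sketch of the Sinkhorn scaling loop that the tests above
+# exercise (hypothetical helper, not POT's implementation): with a dense
+# Gibbs kernel K = exp(-M / reg), alternating the updates below drives the
+# row and column sums of diag(u) K diag(v) to the target marginals.
+def _sinkhorn_scaling_sketch(a, b, M, reg, n_iter=1000):
+    K = np.exp(-M / reg)
+    u = np.ones_like(a)
+    v = np.ones_like(b)
+    for _ in range(n_iter):
+        v = b / K.T.dot(u)  # enforce the column marginal b
+        u = a / K.dot(v)    # enforce the row marginal a
+    return u[:, None] * K * v[None, :]
+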
+
+def test_sinkhorn_empty():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
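+    # empty weight lists make ot.sinkhorn assume uniform weights, so the
+    # marginals of G should still match u below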
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
+    # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
+ method='sinkhorn_stabilized', verbose=True, log=True)
+    # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
+ G, log = ot.sinkhorn(
+ [], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
+ verbose=True, log=True)
+    # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
+
+def test_sinkhorn_variants():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
+ Ges = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
+ G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Ges, atol=1e-05)
+ np.testing.assert_allclose(G0, G_green, atol=1e-5)
+ print(G0, G_green)
+
+
+def test_sinkhorn_variants_log():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ Ges, loges = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Ges, atol=1e-05)
+ np.testing.assert_allclose(G0, G_green, atol=1e-5)
+ print(G0, G_green)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_barycenter(method):
+
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-2
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+
+ ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+
+
+def test_barycenter_stabilization():
+
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-2
+ bar_stable = ot.bregman.barycenter(A, M, reg, weights,
+ method="sinkhorn_stabilized",
+ stopThr=1e-8)
+ bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
+ stopThr=1e-8)
+ np.testing.assert_allclose(bar, bar_stable)
+
+
+def test_wasserstein_bary_2d():
+
+ size = 100 # size of a square image
+ a1 = np.random.randn(size, size)
+ a1 += a1.min()
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.randn(size, size)
+ a2 += a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+
+ # wasserstein
+ reg = 1e-2
+ bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+
+ # help in checking if log and verbose do not bug the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+def test_unmix():
+
+ n_bins = 50 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ a = ot.datasets.make_1D_gauss(n_bins, m=30, s=10)
+
+    # creating dictionary matrix D containing both distributions
+ D = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ M0 = ot.utils.dist0(2)
+ M0 /= M0.max()
+ h0 = ot.unif(2)
+
+ # wasserstein
+ reg = 1e-3
+    um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+
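+    # a (mean 30) lies midway between a1 (mean 20) and a2 (mean 40), so the
+    # unmixing should recover weights close to [0.5, 0.5]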
+ np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+
+ ot.bregman.unmix(a, D, M, M0, h0, reg,
+ 1, alpha=0.01, log=True, verbose=True)
+
+
+def test_empirical_sinkhorn():
+ # test sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+ G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+ G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
+ sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+ loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+    # check constraints
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05)  # metric minkowski
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05)  # metric minkowski
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence():
+    # Test sinkhorn divergence
+ n = 10
+ a = ot.unif(n)
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_s = ot.dist(X_s, X_s)
+ M_t = ot.dist(X_t, X_t)
+
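+    # Sinkhorn divergence: S(a, b) = OT(a, b) - (OT(a, a) + OT(b, b)) / 2,
+    # computed both from samples (empirical) and from precomputed cost matrices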
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+ sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
+
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+ sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
+ sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
+ sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
+ sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
+
+    # check constraints
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
+ np.testing.assert_allclose(
+ emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+
+def test_stabilized_vs_sinkhorn_multidim():
+ # test if stable version matches sinkhorn
+ # for multidimensional inputs
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ log=True)
+ G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling']
+ NOT_VALID_TOKENS = ['foo']
+    # test that valid method names run and invalid tokens raise ValueError
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+    b = ot.utils.unif(n)
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+
+ for method in IMPLEMENTED_METHODS:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ for method in ONLY_1D_methods:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ with pytest.raises(ValueError):
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
diff --git a/test/test_da.py b/test/test_da.py
new file mode 100644
index 0000000..2a5e50e
--- /dev/null
+++ b/test/test_da.py
@@ -0,0 +1,551 @@
+"""Tests for module da on Domain Adaptation """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+from numpy.testing import assert_allclose, assert_equal
+
+import ot
+from ot.datasets import make_data_classif
+from ot.utils import unif
+
+
+def test_sinkhorn_lpl1_transport_class():
+    """test_sinkhorn_lpl1_transport
+    """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.SinkhornLpl1Transport()
+
+    # test that it is computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornLpl1Transport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornLpl1Transport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+ mass_semi = np.sum(
+ otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
+ assert mass_semi == 0, "semisupervised mode not working"
+
+
+def test_sinkhorn_l1l2_transport_class():
+    """test_sinkhorn_l1l2_transport
+    """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.SinkhornL1l2Transport()
+
+    # test that it is computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornL1l2Transport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornL1l2Transport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+    mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
+ assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ rtol=1e-9, atol=1e-9)
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornL1l2Transport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
+def test_sinkhorn_transport_class():
+ """test_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.SinkhornTransport()
+
+    # test that it is computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+ mass_semi = np.sum(
+ otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
+ assert mass_semi == 0, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
+def test_unbalanced_sinkhorn_transport_class():
+    """test_unbalanced_sinkhorn_transport
+    """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.UnbalancedSinkhornTransport()
+
+    # test that it is computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
+def test_emd_transport_class():
+    """test_emd_transport
+    """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.EMDTransport()
+
+    # test that it is computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.EMDTransport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.EMDTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+    mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
+
+    # we need a fairly loose tolerance here, otherwise the test breaks
+ assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ rtol=1e-2, atol=1e-2)
+
+
+def test_mapping_transport_class():
+ """test_mapping_transport
+ """
+
+ ns = 60
+ nt = 120
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+
+ ##########################################################################
+ # kernel == linear mapping tests
+ ##########################################################################
+
+ # check computation and dimensions if bias == False
+ otda = ot.da.MappingTransport(kernel="linear", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "mapping_")
+ assert hasattr(otda, "log_")
+
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # check computation and dimensions if bias == True
+ otda = ot.da.MappingTransport(kernel="linear", bias=True)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ ##########################################################################
+ # kernel == gaussian mapping tests
+ ##########################################################################
+
+ # check computation and dimensions if bias == False
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # check computation and dimensions if bias == True
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=True)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # check everything runs well with log=True
+ otda = ot.da.MappingTransport(kernel="gaussian", log=True)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
+def test_linear_mapping():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ A, b = ot.da.OT_mapping_linear(Xs, Xt)
+
+ Xst = Xs.dot(A) + b
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_linear_mapping_class():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otmap = ot.da.LinearTransport()
+
+ otmap.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otmap, "A_")
+ assert hasattr(otmap, "B_")
+ assert hasattr(otmap, "A1_")
+ assert hasattr(otmap, "B1_")
+
+ Xst = otmap.transform(Xs=Xs)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
diff --git a/test/test_dr.py b/test/test_dr.py
new file mode 100644
index 0000000..c5df287
--- /dev/null
+++ b/test/test_dr.py
@@ -0,0 +1,59 @@
+"""Tests for module dr on Dimensionality Reduction """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+
+try: # test if autograd and pymanopt are installed
+ import ot.dr
+ nogo = False
+except ImportError:
+ nogo = True
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_fda():
+
+ n_samples = 90 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 1
+
+ Pfda, projfda = ot.dr.fda(xs, ys, p)
+
+ projfda(xs)
+
+ np.testing.assert_allclose(np.sum(Pfda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda():
+
+ n_samples = 100 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 2
+
+ Pwda, projwda = ot.dr.wda(xs, ys, p, maxiter=10)
+
+ projwda(xs)
+
+ np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
diff --git a/test/test_gpu.py b/test/test_gpu.py
new file mode 100644
index 0000000..8e62a74
--- /dev/null
+++ b/test/test_gpu.py
@@ -0,0 +1,106 @@
+"""Tests for module gpu for gpu acceleration """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+
+try:  # test if ot.gpu and its GPU backend are available
+ import ot.gpu
+ nogpu = False
+except ImportError:
+ nogpu = True
+
+
+@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_old_doctests():
+ a = [.5, .5]
+ b = [.5, .5]
+ M = [[0., 1.], [1., 0.]]
+ G = ot.sinkhorn(a, b, M, 1)
+ np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]]))
+
+
+@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_dist():
+
+ rng = np.random.RandomState(0)
+
+ for n_samples in [50, 100, 500, 1000]:
+ print(n_samples)
+ a = rng.rand(n_samples // 4, 100)
+ b = rng.rand(n_samples, 100)
+
+ M = ot.dist(a.copy(), b.copy())
+ M2 = ot.gpu.dist(a.copy(), b.copy())
+
+ np.testing.assert_allclose(M, M2, rtol=1e-10)
+
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False)
+
+ # check raise not implemented wrong metric
+ with pytest.raises(NotImplementedError):
+ M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False)
+
+
+@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_sinkhorn():
+
+ rng = np.random.RandomState(0)
+
+ for n_samples in [50, 100, 500, 1000]:
+ a = rng.rand(n_samples // 4, 100)
+ b = rng.rand(n_samples, 100)
+
+ wa = ot.unif(n_samples // 4)
+ wb = ot.unif(n_samples)
+
+ wb2 = np.random.rand(n_samples, 20)
+ wb2 /= wb2.sum(0, keepdims=True)
+
+ M = ot.dist(a.copy(), b.copy())
+ M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
+
+ reg = 1
+
+ G = ot.sinkhorn(wa, wb, M, reg)
+ G1 = ot.gpu.sinkhorn(wa, wb, M, reg)
+
+ np.testing.assert_allclose(G1, G, rtol=1e-10)
+
+ # run all on gpu
+ ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True)
+
+ # run sinkhorn for multiple targets
+ ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True)
+
+
+@pytest.mark.skipif(nogpu, reason="No GPU available")
+def test_gpu_sinkhorn_lpl1():
+
+ rng = np.random.RandomState(0)
+
+ for n_samples in [50, 100, 500]:
+ print(n_samples)
+ a = rng.rand(n_samples // 4, 100)
+ labels_a = np.random.randint(10, size=(n_samples // 4))
+ b = rng.rand(n_samples, 100)
+
+ wa = ot.unif(n_samples // 4)
+ wb = ot.unif(n_samples)
+
+ M = ot.dist(a.copy(), b.copy())
+ M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
+
+ reg = 1
+
+ G = ot.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
+ G1 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
+
+ np.testing.assert_allclose(G1, G, rtol=1e-10)
+
+ ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True)
diff --git a/test/test_gromov.py b/test/test_gromov.py
new file mode 100644
index 0000000..70fa83f
--- /dev/null
+++ b/test/test_gromov.py
@@ -0,0 +1,246 @@
+"""Tests for module gromov """
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
+
+    # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
+
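+    # xt is xs in reverse order, so the optimal coupling puts mass 1 / n on
+    # the anti-diagonal, i.e. the flipped identity checked below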
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04)
+
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+    # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_entropic_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+
+    # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+    # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_barycenter():
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+    Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+                                      [ot.unif(ns), ot.unif(nt)],
+                                      ot.unif(n_samples), [.5, .5],
+                                      'square_loss',
+                                      max_iter=100, tol=1e-3,
+                                      verbose=True)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+    Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+                                       [ot.unif(ns), ot.unif(nt)],
+                                       ot.unif(n_samples), [.5, .5],
+                                       'kl_loss',
+                                       max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_gromov_entropic_barycenter():
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+    Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+                                               [ot.unif(ns), ot.unif(nt)],
+                                               ot.unif(n_samples), [.5, .5],
+                                               'square_loss', 2e-3,
+                                               max_iter=100, tol=1e-3,
+                                               verbose=True)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+    Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+                                                [ot.unif(ns), ot.unif(nt)],
+                                                ot.unif(n_samples), [.5, .5],
+                                                'kl_loss', 2e-3,
+                                                max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_fgw():
+
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+
+    # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+ Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+    # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_fgw_barycenter():
+ np.random.seed(42)
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+
+ ys = np.random.randn(Xs.shape[0], 2)
+ yt = np.random.randn(Xt.shape[0], 2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ xalea = np.random.randn(n_samples, 2)
+ init_C = ot.dist(xalea, xalea)
+
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ init_X = np.random.randn(n_samples, ys.shape[1])
+
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_optim.py b/test/test_optim.py
new file mode 100644
index 0000000..ae31e1f
--- /dev/null
+++ b/test/test_optim.py
@@ -0,0 +1,73 @@
+"""Tests for module optim fro OT optimization """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_conditional_gradient():
+
+ n_bins = 100 # nb bins
+ np.random.seed(0)
+ # bin positions
+ x = np.arange(n_bins, dtype=np.float64)
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n_bins, m=20, s=5) # m= mean, s= std
+ b = ot.datasets.make_1D_gauss(n_bins, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n_bins, 1)), x.reshape((n_bins, 1)))
+ M /= M.max()
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+ def df(G):
+ return G
+
+ reg = 1e-1
+
+ G, log = ot.optim.cg(a, b, M, reg, f, df, verbose=True, log=True)
+
+ np.testing.assert_allclose(a, G.sum(1))
+ np.testing.assert_allclose(b, G.sum(0))
+
+
+def test_generalized_conditional_gradient():
+
+ n_bins = 100 # nb bins
+ np.random.seed(0)
+ # bin positions
+ x = np.arange(n_bins, dtype=np.float64)
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n_bins, m=20, s=5) # m= mean, s= std
+ b = ot.datasets.make_1D_gauss(n_bins, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n_bins, 1)), x.reshape((n_bins, 1)))
+ M /= M.max()
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+ def df(G):
+ return G
+
+ reg1 = 1e-3
+ reg2 = 1e-1
+
+ G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
+
+ np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_solve_1d_linesearch_quad():
+    np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
+    np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
+    np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
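+
+
+# The expected values above follow from the closed form for
+# min over t in [0, 1] of a*t**2 + b*t + c: the unconstrained minimizer
+# -b / (2a) clipped to [0, 1] when the parabola is convex, an endpoint
+# otherwise. A minimal sketch of that formula (hypothetical helper, with the
+# coefficient convention inferred from the assertions, not POT's code):
+def _linesearch_quad_sketch(a, b, c):
+    if a > 0:  # convex: clip the stationary point to the feasible interval
+        return min(1.0, max(0.0, -b / (2.0 * a)))
+    # concave or linear: minimum at an endpoint, f(0) = c, f(1) = a + b + c
+    return 0.0 if a + b >= 0 else 1.0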
diff --git a/test/test_ot.py b/test/test_ot.py
new file mode 100644
index 0000000..dacae0a
--- /dev/null
+++ b/test/test_ot.py
@@ -0,0 +1,308 @@
+"""Tests for main module ot """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import warnings
+
+import numpy as np
+from scipy.stats import wasserstein_distance
+
+import ot
+from ot.datasets import make_1D_gauss as gauss
+import pytest
+
+
+def test_emd_emd2():
+ # test emd and emd2 for simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.emd(u, u, M)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+ # check constraints
+    np.testing.assert_allclose(u, G.sum(1))  # marginal constraints (exact emd)
+    np.testing.assert_allclose(u, G.sum(0))  # marginal constraints (exact emd)
+
+ w = ot.emd2(u, u, M)
+ # check loss=0
+ np.testing.assert_allclose(w, 0)
+
+
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_wass_1d():
+    # test wasserstein_1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+
+ wass1d = ot.wasserstein_1d(u, v, [], [], p=2.)
+
+    # emd with sqeuclidean cost gives W2**2, while wasserstein_1d with p=2
+    # returns W2, hence the square root below
+ np.testing.assert_allclose(np.sqrt(wass), wass1d)
+
+
+def test_emd_empty():
+    # test emd and emd2 with empty (uniform) weights on a simple identity
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.emd([], [], M)
+
+ # check G is identity
+ np.testing.assert_allclose(G, np.eye(n) / n)
+ # check constraints
+    np.testing.assert_allclose(u, G.sum(1))  # marginal constraints (exact emd)
+    np.testing.assert_allclose(u, G.sum(0))  # marginal constraints (exact emd)
+
+ w = ot.emd2([], [], M)
+ # check loss=0
+ np.testing.assert_allclose(w, 0)
+
+
+def test_emd2_multi():
+ n = 500 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+
+ ls = np.arange(20, 500, 20)
+ nb = len(ls)
+ b = np.zeros((n, nb))
+ for i in range(nb):
+ b[:, i] = gauss(n, m=ls[i], s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ # M/=M.max()
+
+ print('Computing {} EMD '.format(nb))
+
+ # emd loss 1 proc
+ ot.tic()
+ emd1 = ot.emd2(a, b, M, 1)
+ ot.toc('1 proc : {} s')
+
+    # emd loss with multiple processes
+ ot.tic()
+ emdn = ot.emd2(a, b, M)
+ ot.toc('multi proc : {} s')
+
+ np.testing.assert_allclose(emd1, emdn)
+
+    # emd loss with multiple processes and log
+ ot.tic()
+ emdn = ot.emd2(a, b, M, log=True, return_matrix=True)
+ ot.toc('multi proc : {} s')
+
+ for i in range(len(emdn)):
+ emd = emdn[i]
+ log = emd[1]
+ cost = emd[0]
+ check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost)
+ emdn[i] = cost
+
+ emdn = np.array(emdn)
+ np.testing.assert_allclose(emd1, emdn)
+
+
+def test_lp_barycenter():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5])
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+
+def test_free_support_barycenter():
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # obvious barycenter location between two diracs
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
+@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
+def test_lp_barycenter_cvxopt():
+
+ a1 = np.array([1.0, 0, 0])[:, None]
+ a2 = np.array([0, 0, 1.0])[:, None]
+
+ A = np.hstack((a1, a2))
+ M = np.array([[0, 1.0, 4.0], [1.0, 0, 1.0], [4.0, 1.0, 0]])
+
+ # obvious barycenter between two diracs
+ bary0 = np.array([0, 1.0, 0])
+
+ bary = ot.lp.barycenter(A, M, [.5, .5], solver=None)
+
+ np.testing.assert_allclose(bary, bary0, rtol=1e-5, atol=1e-7)
+ np.testing.assert_allclose(bary.sum(), 1)
+
+
+def test_warnings():
+ n = 100 # nb bins
+ m = 100 # nb bins
+
+ mean1 = 30
+ mean2 = 50
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+ y = np.arange(m, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=mean1, s=5) # m= mean, s= std
+
+ b = gauss(m, m=mean2, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+
+ print('Computing {} EMD '.format(1))
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ print('Computing {} EMD '.format(1))
+ ot.emd(a, b, M, numItermax=1)
+ assert "numItermax" in str(w[-1].message)
+ assert len(w) == 1
+ a[0] = 100
+ print('Computing {} EMD '.format(2))
+ ot.emd(a, b, M)
+ assert "infeasible" in str(w[-1].message)
+ assert len(w) == 2
+ a[0] = -1
+ print('Computing {} EMD '.format(2))
+ ot.emd(a, b, M)
+ assert "infeasible" in str(w[-1].message)
+ assert len(w) == 3
+
+
+def test_dual_variables():
+ n = 500 # nb bins
+ m = 600 # nb bins
+
+ mean1 = 300
+ mean2 = 400
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+ y = np.arange(m, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=mean1, s=5) # m= mean, s= std
+
+ b = gauss(m, m=mean2, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+
+ print('Computing {} EMD '.format(1))
+
+ # emd loss 1 proc
+ ot.tic()
+ G, log = ot.emd(a, b, M, log=True)
+ ot.toc('1 proc : {} s')
+
+ ot.tic()
+ G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
+ ot.toc('1 proc : {} s')
+
+ cost1 = (G * M).sum()
+ # Check symmetry
+ np.testing.assert_array_almost_equal(cost1, (M * G2.T).sum())
+ # Check with closed-form solution for gaussians
+ np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2))
+
+ # Check that both cost computations are equivalent
+ np.testing.assert_almost_equal(cost1, log['cost'])
+ check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
+
+
+def check_duality_gap(a, b, M, G, u, v, cost):
+ cost_dual = np.vdot(a, u) + np.vdot(b, v)
+ # Check that dual and primal cost are equal
+ np.testing.assert_almost_equal(cost_dual, cost)
+
+ [ind1, ind2] = np.nonzero(G)
+
+ # Check that reduced cost is zero on transport arcs
+ np.testing.assert_array_almost_equal((M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2],
+ np.zeros(ind1.size))
diff --git a/test/test_plot.py b/test/test_plot.py
new file mode 100644
index 0000000..caf84de
--- /dev/null
+++ b/test/test_plot.py
@@ -0,0 +1,60 @@
+"""Tests for module plot for visualization """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+
+try: # test if matplotlib is installed
+ import matplotlib
+ matplotlib.use('Agg')
+ nogo = False
+except ImportError:
+ nogo = True
+
+
+@pytest.mark.skipif(nogo, reason="Matplotlib not installed")
+def test_plot1D_mat():
+
+ import ot
+ import ot.plot
+
+ n_bins = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n_bins, dtype=np.float64)
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n_bins, m=20, s=5) # m= mean, s= std
+ b = ot.datasets.make_1D_gauss(n_bins, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n_bins, 1)), x.reshape((n_bins, 1)))
+ M /= M.max()
+
+ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+@pytest.mark.skipif(nogo, reason="Matplotlib not installed")
+def test_plot2D_samples_mat():
+
+ import ot
+ import ot.plot
+
+ n_bins = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([4, 4])
+ cov_t = np.array([[1, -.8], [-.8, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_bins, mu_s, cov_s)
+ xt = ot.datasets.make_2D_samples_gauss(n_bins, mu_t, cov_t)
+
+ G = 1.0 * (np.random.rand(n_bins, n_bins) < 0.01)
+
+ ot.plot.plot2D_samples_mat(xs, xt, G, thr=1e-5)
diff --git a/test/test_smooth.py b/test/test_smooth.py
new file mode 100644
index 0000000..2afa4f8
--- /dev/null
+++ b/test/test_smooth.py
@@ -0,0 +1,79 @@
+"""Tests for ot.smooth model """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+
+
+def test_smooth_ot_dual():
+
+ # get data
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ with pytest.raises(NotImplementedError):
+ Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none')
+
+ Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+    # kl regularization
+    G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ np.testing.assert_allclose(G, G2, atol=1e-05)
+
+
+def test_smooth_ot_semi_dual():
+
+ # get data
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ with pytest.raises(NotImplementedError):
+ Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none')
+
+ Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+    # kl regularization
+    G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+
+ G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ np.testing.assert_allclose(G, G2, atol=1e-05)
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
new file mode 100644
index 0000000..f0f3fc8
--- /dev/null
+++ b/test/test_stochastic.py
@@ -0,0 +1,215 @@
+"""
+==========================
+Stochastic test
+==========================
+
+This example is designed to test the stochastic optimization algorithms module
+for discrete and semi-continuous measures from the POT library.
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+#############################################################################
+# COMPUTE TEST FOR SEMI-DUAL PROBLEM
+#############################################################################
+
+#############################################################################
+#
+# TEST SAG algorithm
+# ---------------------------------------------
+# 2 identical discrete measures u defined on the same space with a
+# regularization term, a learning rate and a number of iterations
+
+
+def test_stochastic_sag():
+ # test sag
+ n = 15
+ reg = 1
+ numItermax = 30000
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
+ numItermax=numItermax)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-04) # cf convergence sag
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-04) # cf convergence sag
+
+
+#############################################################################
+#
+# TEST ASGD algorithm
+# ---------------------------------------------
+# 2 identical discrete measures u defined on the same space with a
+# regularization term, a learning rate and a number of iterations
+
+
+def test_stochastic_asgd():
+ # test asgd
+ n = 15
+ reg = 1
+ numItermax = 100000
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
+ numItermax=numItermax)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-03) # cf convergence asgd
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-03) # cf convergence asgd
+
+
+#############################################################################
+#
+# TEST Convergence SAG and ASGD toward Sinkhorn's solution
+# --------------------------------------------------------
+# 2 identical discrete measures u defined on the same space with a
+# regularization term, a learning rate and a number of iterations
+
+
+def test_sag_asgd_sinkhorn():
+ # test all algorithms
+ n = 15
+ reg = 1
+ nb_iter = 100000
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
+ numItermax=nb_iter)
+ G_sag = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
+ numItermax=nb_iter)
+ G_sinkhorn = ot.sinkhorn(u, u, M, reg)
+
+    # check constraints
+ np.testing.assert_allclose(
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ np.testing.assert_allclose(
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ np.testing.assert_allclose(
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
+ np.testing.assert_allclose(
+ G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd
+
+
+#############################################################################
+# COMPUTE TEST FOR DUAL PROBLEM
+#############################################################################
+
+#############################################################################
+#
+# TEST SGD algorithm
+# ---------------------------------------------
+# 2 identical discrete measures u defined on the same space with a
+# regularization term, a batch_size and a number of iterations
+
+
+def test_stochastic_dual_sgd():
+ # test sgd
+ n = 10
+ reg = 1
+ numItermax = 15000
+ batch_size = 10
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
+ numItermax=numItermax)
+
+    # check constraints
+ np.testing.assert_allclose(
+ u, G.sum(1), atol=1e-03) # cf convergence sgd
+ np.testing.assert_allclose(
+ u, G.sum(0), atol=1e-03) # cf convergence sgd
+
+
+#############################################################################
+#
+# TEST Convergence SGD toward Sinkhorn's solution
+# --------------------------------------------------------
+# 2 identical discrete measures u defined on the same space with a
+# regularization term, a batch_size and a number of iterations
+
+
+def test_dual_sgd_sinkhorn():
+ # test all dual algorithms
+ n = 10
+ reg = 1
+ nb_iter = 15000
+ batch_size = 10
+ rng = np.random.RandomState(0)
+
+    # Test uniform
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ M = ot.dist(x, x)
+
+ G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
+ numItermax=nb_iter)
+
+ G_sinkhorn = ot.sinkhorn(u, u, M, reg)
+
+    # check constraints
+ np.testing.assert_allclose(
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+
+    # Test Gaussian
+ n = 30
+ reg = 1
+ batch_size = 30
+
+    a = ot.datasets.make_1D_gauss(n, 15, 5)  # m = mean, s = std
+ b = ot.datasets.make_1D_gauss(n, 15, 5)
+ X_source = np.arange(n, dtype=np.float64)
+ Y_target = np.arange(n, dtype=np.float64)
+ M = ot.dist(X_source.reshape((n, 1)), Y_target.reshape((n, 1)))
+ M /= M.max()
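+    # normalizing the cost to [0, 1] keeps reg = 1 on a comparable scale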
+
+ G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
+ numItermax=nb_iter)
+
+ G_sinkhorn = ot.sinkhorn(a, b, M, reg)
+
+    # check constraints
+ np.testing.assert_allclose(
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ np.testing.assert_allclose(
+ G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
new file mode 100644
index 0000000..ca1efba
--- /dev/null
+++ b/test/test_unbalanced.py
@@ -0,0 +1,221 @@
+"""Tests for module Unbalanced OT with entropy regularization"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+import pytest
+from ot.unbalanced import barycenter_unbalanced
+
+from scipy.special import logsumexp
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_convergence(method):
+ # test generalized sinkhorn for unbalanced OT
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
+ log=True)
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ # check fixed point equations
+ # in log-domain
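+    # at convergence the scalings satisfy u = (a / Kv)^fi and
+    # v = (b / Ktu)^fi with fi = reg_m / (reg_m + epsilon), i.e. in the
+    # log-domain: log u = fi * (log a - log Kv), and similarly for v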
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
+
+ np.testing.assert_allclose(
+ u_final, log["logu"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["logv"], atol=1e-05)
+
+ # check if sinkhorn_unbalanced2 returns the correct loss
+ np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_multiple_inputs(method):
+ # test generalized sinkhorn for unbalanced OT
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+    # create several unbalanced target histograms, one per column
+ b = rng.rand(n, 2)
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
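+    # b has one column per target histogram, so the solver returns one
+    # loss value per column (checked below)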
+ loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
+ log=True)
+ # check fixed point equations
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
+
+ np.testing.assert_allclose(
+ u_final, log["logu"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["logv"], atol=1e-05)
+
+ assert len(loss) == b.shape[1]
+
+
+def test_stabilized_vs_sinkhorn():
+    # test if the stabilized version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+    a = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m = mean, s = std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
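+    # normalize the cost so its median is 1, keeping epsilon on a
+    # comparable scale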
+ epsilon = 0.1
+ reg_m = 1.
+ G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ reg_m=reg_m,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_barycenter(method):
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 2])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method, log=True)
+ # check fixed point equations
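+    # same fixed-point equations as in test_unbalanced_convergence,
+    # now with one column of scalings per input histogram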
+ fi = reg_m / (reg_m + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
+
+ np.testing.assert_allclose(
+ u_final, log["logu"], atol=1e-05)
+ np.testing.assert_allclose(
+ v_final, log["logv"], atol=1e-05)
+
+
+def test_barycenter_stabilized_vs_sinkhorn():
+    # test if the stabilized barycenter matches the sinkhorn one
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 4])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 0.5
+ reg_m = 10
+
+ qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
+ reg_m=reg_m, log=True,
+ tau=100,
+ method="sinkhorn_stabilized",
+ )
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method="sinkhorn",
+ log=True)
+
+ np.testing.assert_allclose(
+ q, qstable, atol=1e-05)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
+ NOT_VALID_TOKENS = ['foo']
+    # check that implemented, pending and invalid method names are handled
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+ for method in IMPLEMENTED_METHODS:
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
+ with pytest.warns(UserWarning, match='not implemented'):
+ for method in set(TO_BE_IMPLEMENTED_METHODS):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
+ method=method)
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
diff --git a/test/test_utils.py b/test/test_utils.py
new file mode 100644
index 0000000..640598d
--- /dev/null
+++ b/test/test_utils.py
@@ -0,0 +1,202 @@
+"""Tests for module utils for timing and parallel computation """
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+
+import ot
+import numpy as np
+import sys
+
+
+def test_parmap():
+
+ n = 10
+
+ def f(i):
+ return 1.0 * i * i
+
+ a = np.arange(n)
+
+ l1 = list(map(f, a))
+
+ l2 = list(ot.utils.parmap(f, a))
+
+ np.testing.assert_allclose(l1, l2)
+
+
+def test_tic_toc():
+
+ import time
+
+ ot.tic()
+ time.sleep(0.5)
+ t = ot.toc()
+ t2 = ot.toq()
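+    # toc prints and returns the elapsed time, toq only returns it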
+
+ # test timing
+ np.testing.assert_allclose(0.5, t, rtol=1e-2, atol=1e-2)
+
+ # test toc vs toq
+ np.testing.assert_allclose(t, t2, rtol=1e-2, atol=1e-2)
+
+
+def test_kernel():
+
+ n = 100
+
+ x = np.random.randn(n, 2)
+
+ K = ot.utils.kernel(x, x)
+
+    # the Gaussian kernel has ones on the diagonal since k(x, x) = 1
+ np.testing.assert_allclose(np.diag(K), np.ones(n))
+
+
+def test_unif():
+
+ n = 100
+
+ u = ot.unif(n)
+
+ np.testing.assert_allclose(1, np.sum(u))
+
+
+def test_dist():
+
+ n = 100
+
+ x = np.random.randn(n, 2)
+
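+    # naive O(n^2) reference computation of squared Euclidean distances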
+ D = np.zeros((n, n))
+ for i in range(n):
+ for j in range(n):
+ D[i, j] = np.sum(np.square(x[i, :] - x[j, :]))
+
+ D2 = ot.dist(x, x)
+ D3 = ot.dist(x)
+
+    # dist should return the squared Euclidean distance
+ np.testing.assert_allclose(D, D2)
+ np.testing.assert_allclose(D, D3)
+
+
+def test_dist0():
+
+ n = 100
+ M = ot.utils.dist0(n, method='lin_square')
+
+    # dist0 defaults to linear sampling with quadratic loss
+ np.testing.assert_allclose(M[0, -1], (n - 1) * (n - 1))
+
+
+def test_dots():
+
+ n1, n2, n3, n4 = 100, 50, 200, 100
+
+ A = np.random.randn(n1, n2)
+ B = np.random.randn(n2, n3)
+ C = np.random.randn(n3, n4)
+
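+    # dots chains matrix products: dots(A, B, C) == A.dot(B).dot(C)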
+ X1 = ot.utils.dots(A, B, C)
+
+ X2 = A.dot(B.dot(C))
+
+ np.testing.assert_allclose(X1, X2)
+
+
+def test_clean_zeros():
+
+ n = 100
+ nz = 50
+ nz2 = 20
+ u1 = ot.unif(n)
+ u1[:nz] = 0
+ u1 = u1 / u1.sum()
+ u2 = ot.unif(n)
+ u2[:nz2] = 0
+ u2 = u2 / u2.sum()
+
+ M = ot.utils.dist0(n)
+
+ a, b, M2 = ot.utils.clean_zeros(u1, u2, M)
+
+ assert len(a) == n - nz
+ assert len(b) == n - nz2
+
+
+def test_cost_normalization():
+
+ C = np.random.rand(10, 10)
+
+    # default: no normalization
+ M0 = ot.utils.cost_normalization(C)
+ np.testing.assert_allclose(C, M0)
+
+ M = ot.utils.cost_normalization(C, 'median')
+ np.testing.assert_allclose(np.median(M), 1)
+
+ M = ot.utils.cost_normalization(C, 'max')
+ np.testing.assert_allclose(M.max(), 1)
+
+ M = ot.utils.cost_normalization(C, 'log')
+ np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+
+ M = ot.utils.cost_normalization(C, 'loglog')
+ np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+
+
+def test_check_params():
+
+ res1 = ot.utils.check_params(first='OK', second=20)
+ assert res1 is True
+
+ res0 = ot.utils.check_params(first='OK', second=None)
+ assert res0 is False
+
+
+def test_deprecated_func():
+
+ @ot.utils.deprecated('deprecated text for fun')
+ def fun():
+ pass
+
+ def fun2():
+ pass
+
+ @ot.utils.deprecated('deprecated text for class')
+ class Class():
+ pass
+
+ if sys.version_info < (3, 5):
+ print('Not tested')
+ else:
+ assert ot.utils._is_deprecated(fun) is True
+
+ assert ot.utils._is_deprecated(fun2) is False
+
+
+def test_BaseEstimator():
+
+ class Class(ot.utils.BaseEstimator):
+
+ def __init__(self, first='spam', second='eggs'):
+
+ self.first = first
+ self.second = second
+
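+    # BaseEstimator provides scikit-learn style get_params / set_params
+    # introspection based on the __init__ signature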
+ cl = Class()
+
+ names = cl._get_param_names()
+ assert 'first' in names
+ assert 'second' in names
+
+ params = cl.get_params()
+ assert 'first' in params
+ assert 'second' in params
+
+ params['first'] = 'spam again'
+ cl.set_params(**params)
+
+ assert cl.first == 'spam again'