summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorngayraud <nat.gayraud@gmail.com>2019-08-12 16:37:58 -0400
committerngayraud <nat.gayraud@gmail.com>2019-08-12 16:37:58 -0400
commit9d4b786a036ac95989825beec819521089fb4feb (patch)
tree3d9fcc4fd26e5d8dbe100d79eddf0776801df33a /test/test_da.py
parent092866815cf906012f9194b87af1e7ae0270f7e7 (diff)
fixes for travis, added test, minor nits
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index f7f3a9d..9efd2d9 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -245,6 +245,79 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
+def test_unbalanced_sinkhorn_transport_class():
+ """test_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.UnbalancedSinkhornTransport()
+
+ # test its computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test margin constraints
+ mu_s = unif(ns)
+ mu_t = unif(nt)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
def test_emd_transport_class():
"""test_sinkhorn_transport
"""