From 6aa0f1f4e275098948d4b312530119e5d95b8884 Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 31 Mar 2020 17:12:28 +0200 Subject: v1 jcpot example test --- test/test_da.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_da.py b/test/test_da.py index 2a5e50e..a8c258a 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -5,7 +5,7 @@ # License: MIT License import numpy as np -from numpy.testing.utils import assert_allclose, assert_equal +from numpy.testing import assert_allclose, assert_equal import ot from ot.datasets import make_data_classif @@ -549,3 +549,64 @@ def test_linear_mapping_class(): Cst = np.cov(Xst.T) np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_jcpot_transport_class(): + """test_jcpot_transport + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + Xs1, ys1 = make_data_classif('3gauss', ns1) + Xs2, ys2 = make_data_classif('3gauss', ns2) + + Xt, yt = make_data_classif('3gauss2', nt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + print(otda.proportions_) + + assert hasattr(otda, "coupling_") + assert hasattr(otda, "proportions_") + assert hasattr(otda, "log_") + + # test dimensions of coupling + for i, xs in enumerate(Xs): + assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0]))) + + # test all margin constraints + mu_t = unif(nt) + + for i in range(len(Xs)): + # test margin constraints w.r.t. uniform target weights for each coupling matrix + assert_allclose( + np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3) + + # test margin constraints w.r.t. modified source weights for each source domain + + D1 = np.zeros((len(np.unique(ys[i])), len(ys[i]))) + for c in np.unique(ys[i]): + nbelemperclass = np.sum(ys[i] == c) + if nbelemperclass != 0: + D1[int(c), ys[i] == c] = 1. + + assert_allclose( + np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = otda.transform(Xs=Xs) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] + #assert_equal(transp_Xs.shape, Xs.shape) + + Xs_new, _ = make_data_classif('3gauss', ns1 + 1) + transp_Xs_new = otda.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) -- cgit v1.2.3