From 547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 1 Apr 2020 09:13:58 +0200 Subject: fix test example add M to log --- test/test_da.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) (limited to 'test/test_da.py') diff --git a/test/test_da.py b/test/test_da.py index 7526f30..a13550c 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -568,7 +568,7 @@ def test_jcpot_transport_class(): Xs = [Xs1, Xs2] ys = [ys1, ys2] - otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True) + otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log = True) # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -591,14 +591,8 @@ def test_jcpot_transport_class(): # test margin constraints w.r.t. modified source weights for each source domain - D1 = np.zeros((len(np.unique(ys[i])), len(ys[i]))) - for c in np.unique(ys[i]): - nbelemperclass = np.sum(ys[i] == c) - if nbelemperclass != 0: - D1[int(c), ys[i] == c] = 1. - assert_allclose( - np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) + np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -609,6 +603,3 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) - - -test_jcpot_transport_class() \ No newline at end of file -- cgit v1.2.3