From bc51793333994a1bf6263c9e9c111d754172fc82 Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 8 Apr 2020 14:35:00 +0200 Subject: added test barycenter + modif target --- test/test_da.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) (limited to 'test/test_da.py') diff --git a/test/test_da.py b/test/test_da.py index b58cf51..c54dab7 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -601,3 +601,31 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) + + +def test_jcpot_barycenter(): + """test_jcpot_barycenter + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + sigma = 0.1 + np.random.seed(1985) + + ps1 = .2 + ps2 = .9 + pt = .4 + + Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) + Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) + Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + _, prop, = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean', + numItermax=10000, stopThr=1e-9, verbose=False, log=False) + + np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) -- cgit v1.2.3