summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-08 14:35:00 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-08 14:35:00 +0200
commitbc51793333994a1bf6263c9e9c111d754172fc82 (patch)
tree6957aedef64317f02df1e3a7c0c02f4c9476f81a /test
parent08d0bf9961567c2366ab8735aa7082b3a5542f6c (diff)
added test barycenter + modif target
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index b58cf51..c54dab7 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -601,3 +601,31 @@ def test_jcpot_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+
+def test_jcpot_barycenter():
+ """test_jcpot_barycenter
+ """
+
+ ns1 = 150
+ ns2 = 150
+ nt = 200
+
+ sigma = 0.1
+ np.random.seed(1985)
+
+ ps1 = .2
+ ps2 = .9
+ pt = .4
+
+ Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1)
+ Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2)
+ Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
+
+ Xs = [Xs1, Xs2]
+ ys = [ys1, ys2]
+
+ _, prop, = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean',
+ numItermax=10000, stopThr=1e-9, verbose=False, log=False)
+
+ np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)