diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-08 14:35:00 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-08 14:35:00 +0200 |
commit | bc51793333994a1bf6263c9e9c111d754172fc82 (patch) | |
tree | 6957aedef64317f02df1e3a7c0c02f4c9476f81a | |
parent | 08d0bf9961567c2366ab8735aa7082b3a5542f6c (diff) |
added test barycenter + modif target
-rw-r--r-- | ot/bregman.py | 2 | ||||
-rw-r--r-- | test/test_da.py | 28 |
2 files changed, 29 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 410ae85..c44c141 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1528,7 +1528,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. Parameters ---------- diff --git a/test/test_da.py b/test/test_da.py index b58cf51..c54dab7 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -601,3 +601,31 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) + + +def test_jcpot_barycenter(): + """test_jcpot_barycenter + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + sigma = 0.1 + np.random.seed(1985) + + ps1 = .2 + ps2 = .9 + pt = .4 + + Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) + Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) + Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + _, prop, = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean', + numItermax=10000, stopThr=1e-9, verbose=False, log=False) + + np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) |