diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-07 13:50:11 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-07 13:50:11 +0200 |
commit | 2c9f992157844d6253a302905417e86580ac6b12 (patch) | |
tree | d7314b752b8d50f7bcc8e34eab11e2a109ddb222 | |
parent | 34e13d467e376e9bfee2eb15771d9308518c2adb (diff) |
upd
-rw-r--r-- | examples/plot_otda_classes.py | 1 | ||||
-rw-r--r-- | examples/plot_otda_jcpot.py | 16 | ||||
-rw-r--r-- | ot/bregman.py | 2 | ||||
-rw-r--r-- | test/test_da.py | 2 |
4 files changed, 10 insertions, 11 deletions
diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py index c311fbd..f028022 100644 --- a/examples/plot_otda_classes.py +++ b/examples/plot_otda_classes.py @@ -17,7 +17,6 @@ approaches currently supported in POT. import matplotlib.pylab as pl import ot - ############################################################################## # Generate data # ------------- diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index ce6b88f..316fa8b 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -118,16 +118,16 @@ pl.axis('off') otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) otda.fit(all_Xr, all_Yr, xt) -ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2']) -ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2']) +ws1 = otda.proportions_.dot(otda.log_['D2'][0]) +ws2 = otda.proportions_.dot(otda.log_['D2'][1]) pl.figure(3) pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) @@ -146,16 +146,16 @@ pl.axis('off') # ---------------------------------------------------------------------------- h_res = np.array([1 - pt, pt]) -ws1 = h_res.dot(otda.log_['all_domains'][0]['D2']) -ws2 = h_res.dot(otda.log_['all_domains'][1]['D2']) +ws1 = h_res.dot(otda.log_['D2'][0]) +ws2 = h_res.dot(otda.log_['D2'][1]) pl.figure(4) pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) diff --git a/ot/bregman.py b/ot/bregman.py index ec81924..61dfa52 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1608,7 +1608,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # build the cost matrix and the Gibbs kernel Mtmp = dist(Xs[d], Xt, metric=metric) Mtmp = Mtmp / np.median(Mtmp) - M.append(M) + M.append(Mtmp) Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) np.divide(Mtmp, -reg, out=Ktmp) diff --git a/test/test_da.py b/test/test_da.py index 372ebd4..4eaf193 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -589,7 +589,7 @@ def test_jcpot_transport_class(): # test margin constraints w.r.t. modified source weights for each source domain assert_allclose( - np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, + np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) # test transform |