diff options
Diffstat (limited to 'examples/plot_otda_jcpot.py')
-rw-r--r-- | examples/plot_otda_jcpot.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index ce6b88f..316fa8b 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -118,16 +118,16 @@ pl.axis('off') otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) otda.fit(all_Xr, all_Yr, xt) -ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2']) -ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2']) +ws1 = otda.proportions_.dot(otda.log_['D2'][0]) +ws2 = otda.proportions_.dot(otda.log_['D2'][1]) pl.figure(3) pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) @@ -146,16 +146,16 @@ pl.axis('off') # ---------------------------------------------------------------------------- h_res = np.array([1 - pt, pt]) -ws1 = h_res.dot(otda.log_['all_domains'][0]['D2']) -ws2 = h_res.dot(otda.log_['all_domains'][1]['D2']) +ws1 = h_res.dot(otda.log_['D2'][0]) +ws2 = h_res.dot(otda.log_['D2'][1]) pl.figure(4) pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) |