summaryrefslogtreecommitdiff
path: root/examples/plot_otda_jcpot.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_otda_jcpot.py')
-rw-r--r--examples/plot_otda_jcpot.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
index 316fa8b..c495690 100644
--- a/examples/plot_otda_jcpot.py
+++ b/examples/plot_otda_jcpot.py
@@ -115,7 +115,7 @@ pl.axis('off')
##############################################################################
# Instantiate JCPOT adaptation algorithm and fit it
# ----------------------------------------------------------------------------
-otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
+otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
otda.fit(all_Xr, all_Yr, xt)
ws1 = otda.proportions_.dot(otda.log_['D2'][0])
@@ -126,8 +126,8 @@ pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
@@ -154,8 +154,8 @@ pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)