summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:13:58 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:13:58 +0200
commit547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 (patch)
treef6795752c32fd95879324fc59ab280d3cb0b2551 /examples
parent439860609df786a877383775dd901afe28480cc9 (diff)
fix test example add M to log
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_otda_jcpot.py20
1 files changed, 6 insertions, 14 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
index 5e5fff8..1641fb0 100644
--- a/examples/plot_otda_jcpot.py
+++ b/examples/plot_otda_jcpot.py
@@ -81,11 +81,7 @@ pl.axis('off')
##############################################################################
# Instantiate Sinkhorn transport algorithm and fit them for all source domains
# ----------------------------------------------------------------------------
-ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-2, metric='euclidean')
-
-M1 = ot.dist(xs1, xt, 'euclidean')
-M2 = ot.dist(xs2, xt, 'euclidean')
-
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')
def print_G(G, xs, ys, xt):
for i in range(G.shape[0]):
@@ -125,7 +121,7 @@ pl.axis('off')
##############################################################################
# Instantiate JCPOT adaptation algorithm and fit it
# ----------------------------------------------------------------------------
-otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, tol=1e-9, verbose=True, log=True)
+otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
otda.fit(all_Xr, all_Yr, xt)
ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2'])
@@ -136,8 +132,8 @@ pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
@@ -154,10 +150,6 @@ pl.axis('off')
##############################################################################
# Run oracle transport algorithm with known proportions
# ----------------------------------------------------------------------------
-
-otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True)
-otda.fit(all_Xr, all_Yr, xt)
-
h_res = np.array([1 - pt, pt])
ws1 = h_res.dot(otda.log_['all_domains'][0]['D2'])
@@ -168,8 +160,8 @@ pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)