summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-07 13:50:11 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-07 13:50:11 +0200
commit2c9f992157844d6253a302905417e86580ac6b12 (patch)
treed7314b752b8d50f7bcc8e34eab11e2a109ddb222 /examples
parent34e13d467e376e9bfee2eb15771d9308518c2adb (diff)
upd
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_otda_classes.py1
-rw-r--r--examples/plot_otda_jcpot.py16
2 files changed, 8 insertions, 9 deletions
diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py
index c311fbd..f028022 100644
--- a/examples/plot_otda_classes.py
+++ b/examples/plot_otda_classes.py
@@ -17,7 +17,6 @@ approaches currently supported in POT.
import matplotlib.pylab as pl
import ot
-
##############################################################################
# Generate data
# -------------
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
index ce6b88f..316fa8b 100644
--- a/examples/plot_otda_jcpot.py
+++ b/examples/plot_otda_jcpot.py
@@ -118,16 +118,16 @@ pl.axis('off')
otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
otda.fit(all_Xr, all_Yr, xt)
-ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2'])
-ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2'])
+ws1 = otda.proportions_.dot(otda.log_['D2'][0])
+ws2 = otda.proportions_.dot(otda.log_['D2'][1])
pl.figure(3)
pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
@@ -146,16 +146,16 @@ pl.axis('off')
# ----------------------------------------------------------------------------
h_res = np.array([1 - pt, pt])
-ws1 = h_res.dot(otda.log_['all_domains'][0]['D2'])
-ws2 = h_res.dot(otda.log_['all_domains'][1]['D2'])
+ws1 = h_res.dot(otda.log_['D2'][0])
+ws2 = h_res.dot(otda.log_['D2'][1])
pl.figure(4)
pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)