summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-03 17:29:13 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-03 17:29:13 +0200
commit98b68f1edc916d3802eeb24a19d0e10d855e01c6 (patch)
tree8830ef44936292de0d048d3c25170f180e843e7f /examples
parentfa99199c02e497354e34c6ce76e7b4ba15b44d05 (diff)
autopep+remove sinkhorn+add simtype
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_otda_laplacian.py38
1 files changed, 8 insertions, 30 deletions
diff --git a/examples/plot_otda_laplacian.py b/examples/plot_otda_laplacian.py
index d9ae280..965380c 100644
--- a/examples/plot_otda_laplacian.py
+++ b/examples/plot_otda_laplacian.py
@@ -5,7 +5,7 @@ OT for domain adaptation
========================
This example introduces a domain adaptation in a 2D setting and OTDA
-approaches with Laplacian regularization.
+approache with Laplacian regularization.
"""
@@ -36,22 +36,17 @@ ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)
# Sinkhorn Transport
-ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.5)
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01)
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
# EMD Transport with Laplacian regularization
ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1)
ot_emd_laplace.fit(Xs=Xs, Xt=Xt)
-# Sinkhorn Transport with Laplacian regularization
-ot_sinkhorn_laplace = ot.da.SinkhornLaplaceTransport(reg_e=.5, reg_lap=100, reg_src=1)
-ot_sinkhorn_laplace.fit(Xs=Xs, Xt=Xt)
-
# transport source samples onto target samples
transp_Xs_emd = ot_emd.transform(Xs=Xs)
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs)
-transp_Xs_sinkhorn_laplace = ot_sinkhorn_laplace.transform(Xs=Xs)
##############################################################################
# Fig 1 : plots source and target samples
@@ -80,35 +75,27 @@ pl.tight_layout()
param_img = {'interpolation': 'nearest'}
-n_plots = 2
-
pl.figure(2, figsize=(15, 8))
-pl.subplot(2, 2*n_plots, 1)
+pl.subplot(2, 3, 1)
pl.imshow(ot_emd.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nEMDTransport')
pl.figure(2, figsize=(15, 8))
-pl.subplot(2, 2*n_plots, 2)
+pl.subplot(2, 3, 2)
pl.imshow(ot_sinkhorn.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nSinkhornTransport')
-pl.subplot(2, 2*n_plots, 3)
+pl.subplot(2, 3, 3)
pl.imshow(ot_emd_laplace.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nEMDLaplaceTransport')
-pl.subplot(2, 2*n_plots, 4)
-pl.imshow(ot_emd_laplace.coupling_, **param_img)
-pl.xticks([])
-pl.yticks([])
-pl.title('Optimal coupling\nSinkhornLaplaceTransport')
-
-pl.subplot(2, 2*n_plots, 5)
+pl.subplot(2, 3, 4)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
@@ -118,7 +105,7 @@ pl.yticks([])
pl.title('Transported samples\nEmdTransport')
pl.legend(loc="lower left")
-pl.subplot(2, 2*n_plots, 6)
+pl.subplot(2, 3, 5)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
@@ -127,7 +114,7 @@ pl.xticks([])
pl.yticks([])
pl.title('Transported samples\nSinkhornTransport')
-pl.subplot(2, 2*n_plots, 7)
+pl.subplot(2, 3, 6)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys,
@@ -135,15 +122,6 @@ pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys,
pl.xticks([])
pl.yticks([])
pl.title('Transported samples\nEMDLaplaceTransport')
-
-pl.subplot(2, 2*n_plots, 8)
-pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
- label='Target samples', alpha=0.3)
-pl.scatter(transp_Xs_sinkhorn_laplace[:, 0], transp_Xs_sinkhorn_laplace[:, 1], c=ys,
- marker='+', label='Transp samples', s=30)
-pl.xticks([])
-pl.yticks([])
-pl.title('Transported samples\nSinkhornLaplaceTransport')
pl.tight_layout()
pl.show()