summaryrefslogtreecommitdiff
path: root/examples/da/plot_otda_classes.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-29 13:34:34 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:09:13 +0200
commit5a9795f08341458bd9e3befe0c2c6ea6fa891323 (patch)
treea721b13aeac706d6de813a2147c5572d057d866c /examples/da/plot_otda_classes.py
parent3730779200896ee7de533eb7c5d7fa19e09eeb25 (diff)
pass on examples | introduced RandomState
Diffstat (limited to 'examples/da/plot_otda_classes.py')
-rw-r--r--examples/da/plot_otda_classes.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/examples/da/plot_otda_classes.py b/examples/da/plot_otda_classes.py
index 6870fa4..ec57a37 100644
--- a/examples/da/plot_otda_classes.py
+++ b/examples/da/plot_otda_classes.py
@@ -15,19 +15,23 @@ approaches currently supported in POT.
# License: MIT License
import matplotlib.pylab as pl
-import numpy as np
import ot
-np.random.seed(42)
-# number of source and target points to generate
-ns = 150
-nt = 150
+##############################################################################
+# generate data
+##############################################################################
+
+n_source_samples = 150
+n_target_samples = 150
+
+Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)
+Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)
-Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
-Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+##############################################################################
# Instantiate the different transport algorithms and fit them
+##############################################################################
# EMD Transport
ot_emd = ot.da.EMDTransport()
@@ -52,6 +56,7 @@ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
+
##############################################################################
# Fig 1 : plots source and target samples
##############################################################################
@@ -72,6 +77,7 @@ pl.legend(loc=0)
pl.title('Target samples')
pl.tight_layout()
+
##############################################################################
# Fig 2 : plot optimal couplings and transported samples
##############################################################################