diff options
author | Slasnista <stan.chambon@gmail.com> | 2017-08-29 13:34:34 +0200 |
---|---|---|
committer | Slasnista <stan.chambon@gmail.com> | 2017-08-29 13:34:34 +0200 |
commit | 65de6fc9add57b95b8968e1e75fe1af342f81d01 (patch) | |
tree | 1b973cb5314af46a28060c477840903bf3fbf4ac /examples/da/plot_otda_mapping.py | |
parent | a29e22db4772ebc4a8266c917e2e662f624c6baa (diff) |
pass on examples | introduced RandomState
Diffstat (limited to 'examples/da/plot_otda_mapping.py')
-rw-r--r-- | examples/da/plot_otda_mapping.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/examples/da/plot_otda_mapping.py b/examples/da/plot_otda_mapping.py index aea7f09..09d2cb4 100644 --- a/examples/da/plot_otda_mapping.py +++ b/examples/da/plot_otda_mapping.py @@ -23,25 +23,31 @@ import matplotlib.pylab as pl import ot -np.random.seed(42) - ############################################################################## -# generate +# generate data ############################################################################## -n = 100 # nb samples in source and target datasets +n_source_samples = 100 +n_target_samples = 100 theta = 2 * np.pi / 20 noise_level = 0.1 -Xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level) -Xs_new, _ = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level) + +Xs, ys = ot.datasets.get_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) +Xs_new, _ = ot.datasets.get_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) Xt, yt = ot.datasets.get_data_classif( - 'gaussrot', n, theta=theta, nz=noise_level) + 'gaussrot', n_target_samples, theta=theta, nz=noise_level) # one of the target mode changes its variance (no linear mapping) Xt[yt == 2] *= 3 Xt = Xt + 4 +############################################################################## +# Instantiate the different transport algorithms and fit them +############################################################################## + # MappingTransport with linear kernel ot_mapping_linear = ot.da.MappingTransport( kernel="linear", mu=1e0, eta=1e-8, bias=True, @@ -80,6 +86,7 @@ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') pl.legend(loc=0) pl.title('Source and target distributions') + ############################################################################## # plot transported samples ############################################################################## |