diff options
Diffstat (limited to 'examples/da/plot_otda_mapping.py')
-rw-r--r-- | examples/da/plot_otda_mapping.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/examples/da/plot_otda_mapping.py b/examples/da/plot_otda_mapping.py index aea7f09..09d2cb4 100644 --- a/examples/da/plot_otda_mapping.py +++ b/examples/da/plot_otda_mapping.py @@ -23,25 +23,31 @@ import matplotlib.pylab as pl import ot -np.random.seed(42) - ############################################################################## -# generate +# generate data ############################################################################## -n = 100 # nb samples in source and target datasets +n_source_samples = 100 +n_target_samples = 100 theta = 2 * np.pi / 20 noise_level = 0.1 -Xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level) -Xs_new, _ = ot.datasets.get_data_classif('gaussrot', n, nz=noise_level) + +Xs, ys = ot.datasets.get_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) +Xs_new, _ = ot.datasets.get_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) Xt, yt = ot.datasets.get_data_classif( - 'gaussrot', n, theta=theta, nz=noise_level) + 'gaussrot', n_target_samples, theta=theta, nz=noise_level) # one of the target mode changes its variance (no linear mapping) Xt[yt == 2] *= 3 Xt = Xt + 4 +############################################################################## +# Instantiate the different transport algorithms and fit them +############################################################################## + # MappingTransport with linear kernel ot_mapping_linear = ot.da.MappingTransport( kernel="linear", mu=1e0, eta=1e-8, bias=True, @@ -80,6 +86,7 @@ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') pl.legend(loc=0) pl.title('Source and target distributions') + ############################################################################## # plot transported samples ############################################################################## |