diff options
Diffstat (limited to 'examples/da/plot_otda_d2.py')
-rw-r--r-- | examples/da/plot_otda_d2.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/examples/da/plot_otda_d2.py b/examples/da/plot_otda_d2.py index 1d2192f..8833eb2 100644 --- a/examples/da/plot_otda_d2.py +++ b/examples/da/plot_otda_d2.py @@ -19,17 +19,19 @@ of what the transport methods are doing. # License: MIT License import matplotlib.pylab as pl +import numpy as np import ot -# number of source and target points to generate -ns = 150 -nt = 150 +np.random.seed(42) -Xs, ys = ot.datasets.get_data_classif('3gauss', ns) -Xt, yt = ot.datasets.get_data_classif('3gauss2', nt) +n_samples_source = 150 +n_samples_target = 150 + +Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source) +Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target) # Cost matrix -M = ot.dist(Xs, Xt) +M = ot.dist(Xs, Xt, metric='sqeuclidean') # Instantiate the different transport algorithms and fit them |