diff options
-rw-r--r-- | examples/plot_OT_1D.py | 10 | ||||
-rw-r--r-- | examples/plot_OT_2D_samples.py | 8 | ||||
-rw-r--r-- | examples/plot_OT_L1_vs_L2.py | 8 | ||||
-rw-r--r-- | examples/plot_WDA.py | 10 | ||||
-rw-r--r-- | examples/plot_barycenter_1D.py | 8 | ||||
-rw-r--r-- | examples/plot_compute_emd.py | 8 | ||||
-rw-r--r-- | examples/plot_optim_OTreg.py | 10 | ||||
-rw-r--r-- | examples/plot_otda_classes.py | 10 | ||||
-rw-r--r-- | examples/plot_otda_color_images.py | 10 | ||||
-rw-r--r-- | examples/plot_otda_d2.py | 17 | ||||
-rw-r--r-- | examples/plot_otda_mapping.py | 8 | ||||
-rw-r--r-- | examples/plot_otda_mapping_colors_images.py | 10 |
12 files changed, 59 insertions, 58 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index b6ffa5f..719058f 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -20,7 +20,7 @@ from ot.datasets import get_1D_gauss as gauss ############################################################################## # Generate data -# ############# +# ------------- #%% parameters @@ -41,7 +41,7 @@ M /= M.max() ############################################################################## # Plot distributions and loss matrix -################################### +# ---------------------------------- #%% plot the distributions @@ -57,7 +57,8 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') ############################################################################## # Solve EMD -############################################################################## +# --------- + #%% EMD @@ -68,7 +69,8 @@ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') ############################################################################## # Solve Sinkhorn -############################################################################## +# -------------- + #%% Sinkhorn diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index f57d631..9818ec5 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -19,7 +19,7 @@ import ot ############################################################################## # Generate data -############################################################################## +# ------------- #%% parameters and data generation @@ -42,7 +42,7 @@ M /= M.max() ############################################################################## # Plot data -############################################################################## +# --------- #%% plot samples @@ -58,7 +58,7 @@ pl.title('Cost matrix M') ############################################################################## # Compute EMD -############################################################################## +# ----------- #%% EMD @@ -78,7 +78,7 @@ pl.title('OT matrix with samples') ############################################################################## # Compute Sinkhorn -############################################################################## +# ---------------- #%% sinkhorn diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index 49d37e1..090e809 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -22,7 +22,7 @@ import ot ############################################################################## # Dataset 1 : uniform sampling -############################################################################## +# ---------------------------- n = 20 # nb samples xs = np.zeros((n, 2)) @@ -73,7 +73,7 @@ pl.tight_layout() ############################################################################## # Dataset 1 : Plot OT Matrices -############################################################################## +# ---------------------------- #%% EMD @@ -114,7 +114,7 @@ pl.show() ############################################################################## # Dataset 2 : Partial circle -############################################################################## +# -------------------------- n = 50 # nb samples xtot = np.zeros((n + 1, 2)) @@ -168,7 +168,7 @@ pl.tight_layout() ############################################################################## # Dataset 2 : Plot OT Matrices -############################################################################## +# ----------------------------- #%% EMD diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py index 5928621..93cc237 100644 --- a/examples/plot_WDA.py +++ b/examples/plot_WDA.py @@ -24,7 +24,7 @@ from ot.dr import wda, fda ############################################################################## # Generate data -############################################################################## +# ------------- #%% parameters @@ -51,7 +51,7 @@ xt = np.hstack((xt, np.random.randn(n, nbnoise))) ############################################################################## # Plot data -############################################################################## +# --------- #%% plot samples pl.figure(1, figsize=(6.4, 3.5)) @@ -69,7 +69,7 @@ pl.tight_layout() ############################################################################## # Compute Fisher Discriminant Analysis -############################################################################## +# ------------------------------------ #%% Compute FDA p = 2 @@ -78,7 +78,7 @@ Pfda, projfda = fda(xs, ys, p) ############################################################################## # Compute Wasserstein Discriminant Analysis -############################################################################## +# ----------------------------------------- #%% Compute WDA p = 2 @@ -91,7 +91,7 @@ Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) ############################################################################## # Plot 2D projections -############################################################################## +# ------------------- #%% plot samples diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py index eef8536..620936b 100644 --- a/examples/plot_barycenter_1D.py +++ b/examples/plot_barycenter_1D.py @@ -27,7 +27,7 @@ from matplotlib.collections import PolyCollection ############################################################################## # Generate data -############################################################################## +# ------------- #%% parameters @@ -50,7 +50,7 @@ M /= M.max() ############################################################################## # Plot data -############################################################################## +# --------- #%% plot the distributions @@ -62,7 +62,7 @@ pl.tight_layout() ############################################################################## # Barycenter computation -############################################################################## +# ---------------------- #%% barycenter computation @@ -92,7 +92,7 @@ pl.tight_layout() ############################################################################## # Barycentric interpolation -############################################################################## +# ------------------------- #%% barycenter interpolation diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index a84b249..73b42c3 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -22,7 +22,7 @@ from ot.datasets import get_1D_gauss as gauss ############################################################################## # Generate data -############################################################################## +# ------------- #%% parameters @@ -51,7 +51,7 @@ M2 /= M2.max() ############################################################################## # Plot data -############################################################################## +# --------- #%% plot the distributions @@ -67,7 +67,7 @@ pl.tight_layout() ############################################################################## # Compute EMD for the different losses -############################################################################## +# ------------------------------------ #%% Compute and plot distributions and loss matrix @@ -83,7 +83,7 @@ pl.legend() ############################################################################## # Compute Sinkhorn for the different losses -############################################################################## +# ----------------------------------------- #%% reg = 1e-2 diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index d753414..e1a737e 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -32,7 +32,7 @@ import ot ############################################################################## # Generate data -############################################################################## +# ------------- #%% parameters @@ -51,7 +51,7 @@ M /= M.max() ############################################################################## # Solve EMD -############################################################################## +# --------- #%% EMD @@ -62,7 +62,7 @@ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') ############################################################################## # Solve EMD with Frobenius norm regularization -############################################################################## +# -------------------------------------------- #%% Example with Frobenius norm regularization @@ -84,7 +84,7 @@ ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg') ############################################################################## # Solve EMD with entropic regularization -############################################################################## +# -------------------------------------- #%% Example with entropic regularization @@ -106,7 +106,7 @@ ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg') ############################################################################## # Solve EMD with Frobenius norm + entropic regularization -############################################################################## +# ------------------------------------------------------- #%% Example with Frobenius norm + entropic regularization with gcg diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py index ec57a37..b14c11a 100644 --- a/examples/plot_otda_classes.py +++ b/examples/plot_otda_classes.py @@ -19,8 +19,8 @@ import ot ############################################################################## -# generate data -############################################################################## +# Generate data +# ------------- n_source_samples = 150 n_target_samples = 150 @@ -31,7 +31,7 @@ Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples) ############################################################################## # Instantiate the different transport algorithms and fit them -############################################################################## +# ----------------------------------------------------------- # EMD Transport ot_emd = ot.da.EMDTransport() @@ -59,7 +59,7 @@ transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs) ############################################################################## # Fig 1 : plots source and target samples -############################################################################## +# --------------------------------------- pl.figure(1, figsize=(10, 5)) pl.subplot(1, 2, 1) @@ -80,7 +80,7 @@ pl.tight_layout() ############################################################################## # Fig 2 : plot optimal couplings and transported samples -############################################################################## +# ------------------------------------------------------ param_img = {'interpolation': 'nearest', 'cmap': 'spectral'} diff --git a/examples/plot_otda_color_images.py b/examples/plot_otda_color_images.py index f1df9d9..e77aec0 100644 --- a/examples/plot_otda_color_images.py +++ b/examples/plot_otda_color_images.py @@ -42,7 +42,7 @@ def minmax(I): ############################################################################## # Generate data -############################################################################## +# ------------- # Loading images I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 @@ -62,7 +62,7 @@ Xt = X2[idx2, :] ############################################################################## # Plot original image -############################################################################## +# ------------------- pl.figure(1, figsize=(6.4, 3)) @@ -79,7 +79,7 @@ pl.title('Image 2') ############################################################################## # Scatter plot of colors -############################################################################## +# ---------------------- pl.figure(2, figsize=(6.4, 3)) @@ -101,7 +101,7 @@ pl.tight_layout() ############################################################################## # Instantiate the different transport algorithms and fit them -############################################################################## +# ----------------------------------------------------------- # EMDTransport ot_emd = ot.da.EMDTransport() @@ -127,7 +127,7 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) ############################################################################## # Plot new images -############################################################################## +# --------------- pl.figure(3, figsize=(8, 4)) diff --git a/examples/plot_otda_d2.py b/examples/plot_otda_d2.py index 3daa0a6..e53d7d6 100644 --- a/examples/plot_otda_d2.py +++ b/examples/plot_otda_d2.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -============================== -OT for empirical distributions -============================== +=================================================== +OT for domain adaptation on empirical distributions +=================================================== This example introduces a domain adaptation in a 2D setting. It explicits the problem of domain adaptation and introduces some optimal transport @@ -24,7 +24,7 @@ import ot ############################################################################## # generate data -############################################################################## +# ------------- n_samples_source = 150 n_samples_target = 150 @@ -38,7 +38,7 @@ M = ot.dist(Xs, Xt, metric='sqeuclidean') ############################################################################## # Instantiate the different transport algorithms and fit them -############################################################################## +# ----------------------------------------------------------- # EMD Transport ot_emd = ot.da.EMDTransport() @@ -60,7 +60,7 @@ transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs) ############################################################################## # Fig 1 : plots source and target samples + matrix of pairwise distance -############################################################################## +# --------------------------------------------------------------------- pl.figure(1, figsize=(10, 10)) pl.subplot(2, 2, 1) @@ -87,8 +87,7 @@ pl.tight_layout() ############################################################################## # Fig 2 : plots optimal couplings for the different methods -############################################################################## - +# --------------------------------------------------------- pl.figure(2, figsize=(10, 6)) pl.subplot(2, 3, 1) @@ -137,7 +136,7 @@ pl.tight_layout() ############################################################################## # Fig 3 : plot transported samples -############################################################################## +# -------------------------------- # display transported samples pl.figure(4, figsize=(10, 4)) diff --git a/examples/plot_otda_mapping.py b/examples/plot_otda_mapping.py index e78fef4..167c3a1 100644 --- a/examples/plot_otda_mapping.py +++ b/examples/plot_otda_mapping.py @@ -25,7 +25,7 @@ import ot ############################################################################## # Generate data -############################################################################## +# ------------- n_source_samples = 100 n_target_samples = 100 @@ -45,7 +45,7 @@ Xt = Xt + 4 ############################################################################## # Plot data -############################################################################## +# --------- pl.figure(1, (10, 5)) pl.clf() @@ -57,7 +57,7 @@ pl.title('Source and target distributions') ############################################################################## # Instantiate the different transport algorithms and fit them -############################################################################## +# ----------------------------------------------------------- # MappingTransport with linear kernel ot_mapping_linear = ot.da.MappingTransport( @@ -88,7 +88,7 @@ transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new) ############################################################################## # Plot transported samples -############################################################################## +# ------------------------ pl.figure(2) pl.clf() diff --git a/examples/plot_otda_mapping_colors_images.py b/examples/plot_otda_mapping_colors_images.py index 5590286..5f1e844 100644 --- a/examples/plot_otda_mapping_colors_images.py +++ b/examples/plot_otda_mapping_colors_images.py @@ -45,7 +45,7 @@ def minmax(I): ############################################################################## # Generate data -############################################################################## +# ------------- # Loading images I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 @@ -66,7 +66,7 @@ Xt = X2[idx2, :] ############################################################################## # Domain adaptation for pixel distribution transfer -############################################################################## +# ------------------------------------------------- # EMDTransport ot_emd = ot.da.EMDTransport() @@ -97,7 +97,7 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape)) ############################################################################## # Plot original images -############################################################################## +# -------------------- pl.figure(1, figsize=(6.4, 3)) pl.subplot(1, 2, 1) @@ -114,7 +114,7 @@ pl.tight_layout() ############################################################################## # Plot pixel values distribution -############################################################################## +# ------------------------------ pl.figure(2, figsize=(6.4, 5)) @@ -136,7 +136,7 @@ pl.tight_layout() ############################################################################## # Plot transformed images -############################################################################## +# ----------------------- pl.figure(2, figsize=(10, 5)) |