From 212f3889b1114026765cda0134e02766daa82af2 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 31 Aug 2017 09:28:37 +0200 Subject: update tests --- examples/plot_OT_1D.py | 3 ++ examples/plot_OT_2D_samples.py | 15 +++++++ examples/plot_WDA.py | 27 ++++++++++++ examples/plot_barycenter_1D.py | 8 ++++ examples/plot_compute_emd.py | 26 ++++++++++-- examples/plot_optim_OTreg.py | 18 ++++++++ examples/plot_otda_color_images.py | 66 ++++++++++++++--------------- examples/plot_otda_mapping.py | 27 ++++++------ examples/plot_otda_mapping_colors_images.py | 15 ++++--- 9 files changed, 148 insertions(+), 57 deletions(-) diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 77114c4..a1473c4 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -4,6 +4,9 @@ 1D optimal transport ==================== +This example illustrate the computation of EMD and Sinkhorn transport plans +and their visualization. + """ # Author: Remi Flamary diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index 2a42dc0..a913b8c 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -14,6 +14,10 @@ import numpy as np import matplotlib.pylab as pl import ot +############################################################################## +# Generate data +############################################################################## + #%% parameters and data generation n = 50 # nb samples @@ -33,6 +37,10 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples M = ot.dist(xs, xt) M /= M.max() +############################################################################## +# Plot data +############################################################################## + #%% plot samples pl.figure(1) @@ -45,6 +53,9 @@ pl.figure(2) pl.imshow(M, interpolation='nearest') pl.title('Cost matrix M') +############################################################################## +# Compute EMD +############################################################################## #%% EMD @@ -62,6 +73,10 @@ pl.legend(loc=0) pl.title('OT matrix with samples') +############################################################################## +# Compute Sinkhorn +############################################################################## + #%% sinkhorn # reg term diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py index 42789f2..06a2e38 100644 --- a/examples/plot_WDA.py +++ b/examples/plot_WDA.py @@ -4,6 +4,12 @@ Wasserstein Discriminant Analysis ================================= +This example illustrate the use of WDA as proposed in [11]. + + +[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). +Wasserstein Discriminant Analysis. + """ # Author: Remi Flamary @@ -16,6 +22,10 @@ import matplotlib.pylab as pl from ot.dr import wda, fda +############################################################################## +# Generate data +############################################################################## + #%% parameters n = 1000 # nb samples in source and target datasets @@ -39,6 +49,10 @@ nbnoise = 8 xs = np.hstack((xs, np.random.randn(n, nbnoise))) xt = np.hstack((xt, np.random.randn(n, nbnoise))) +############################################################################## +# Plot data +############################################################################## + #%% plot samples pl.figure(1, figsize=(6.4, 3.5)) @@ -53,11 +67,19 @@ pl.legend(loc=0) pl.title('Other dimensions') pl.tight_layout() +############################################################################## +# Compute Fisher Discriminant Analysis +############################################################################## + #%% Compute FDA p = 2 Pfda, projfda = fda(xs, ys, p) +############################################################################## +# Compute Wasserstein Discriminant Analysis +############################################################################## + #%% Compute WDA p = 2 reg = 1e0 @@ -66,6 +88,11 @@ maxiter = 100 Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) + +############################################################################## +# Plot 2D projections +############################################################################## + #%% plot samples xsp = projfda(xs) diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py index 875f44c..f3be247 100644 --- a/examples/plot_barycenter_1D.py +++ b/examples/plot_barycenter_1D.py @@ -4,6 +4,14 @@ 1D Wasserstein barycenter demo ============================== +This example illustrate the computation of regularized Wassersyein Barycenter +as proposed in [3]. + + +[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). +Iterative Bregman projections for regularized transportation problems +SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + """ # Author: Remi Flamary diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 893eecf..704da0e 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================== -1D optimal transport -==================== +================= +Plot multiple EMD +================= """ @@ -16,6 +16,10 @@ import ot from ot.datasets import get_1D_gauss as gauss +############################################################################## +# Generate data +############################################################################## + #%% parameters n = 100 # nb bins @@ -40,6 +44,11 @@ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') M /= M.max() M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') M2 /= M2.max() + +############################################################################## +# Plot data +############################################################################## + #%% plot the distributions pl.figure(1) @@ -51,10 +60,15 @@ pl.plot(x, B, label='Target distributions') pl.title('Target distributions') pl.tight_layout() + +############################################################################## +# Compute EMD for the different losses +############################################################################## + #%% Compute and plot distributions and loss matrix d_emd = ot.emd2(a, B, M) # direct computation of EMD -d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M3 +d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 pl.figure(2) @@ -63,6 +77,10 @@ pl.plot(d_emd2, label='Squared Euclidean EMD') pl.title('EMD distances') pl.legend() +############################################################################## +# Compute Sinkhorn for the different losses +############################################################################## + #%% reg = 1e-2 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index 7ef6a6b..95bcdaf 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -4,6 +4,24 @@ Regularized OT with generic solver ================================== +This example illustrate the use of the generic solver for regularized OT with +user designed regularization term. It uses Conditional gradient as in [6] and +generalized Conditional Gradient as proposed in [5][7]. + + +[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for +Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine +Intelligence , vol.PP, no.99, pp.1-1. + +[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). +Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, +7(3), 1853-1882. + +[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized +conditional gradient: analysis of convergence and applications. +arXiv preprint arXiv:1510.06567. + + """ diff --git a/examples/plot_otda_color_images.py b/examples/plot_otda_color_images.py index 46ad44b..f1df9d9 100644 --- a/examples/plot_otda_color_images.py +++ b/examples/plot_otda_color_images.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -======================================================== -OT for domain adaptation with image color adaptation [6] -======================================================== +============================= +OT for image color adaptation +============================= This example presents a way of transferring colors between two image with Optimal Transport as introduced in [6] @@ -41,7 +41,7 @@ def minmax(I): ############################################################################## -# generate data +# Generate data ############################################################################## # Loading images @@ -61,33 +61,7 @@ Xt = X2[idx2, :] ############################################################################## -# Instantiate the different transport algorithms and fit them -############################################################################## - -# EMDTransport -ot_emd = ot.da.EMDTransport() -ot_emd.fit(Xs=Xs, Xt=Xt) - -# SinkhornTransport -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1) -ot_sinkhorn.fit(Xs=Xs, Xt=Xt) - -# prediction between images (using out of sample prediction as in [6]) -transp_Xs_emd = ot_emd.transform(Xs=X1) -transp_Xt_emd = ot_emd.inverse_transform(Xt=X2) - -transp_Xs_sinkhorn = ot_emd.transform(Xs=X1) -transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2) - -I1t = minmax(mat2im(transp_Xs_emd, I1.shape)) -I2t = minmax(mat2im(transp_Xt_emd, I2.shape)) - -I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape)) -I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) - - -############################################################################## -# plot original image +# Plot original image ############################################################################## pl.figure(1, figsize=(6.4, 3)) @@ -104,7 +78,7 @@ pl.title('Image 2') ############################################################################## -# scatter plot of colors +# Scatter plot of colors ############################################################################## pl.figure(2, figsize=(6.4, 3)) @@ -126,7 +100,33 @@ pl.tight_layout() ############################################################################## -# plot new images +# Instantiate the different transport algorithms and fit them +############################################################################## + +# EMDTransport +ot_emd = ot.da.EMDTransport() +ot_emd.fit(Xs=Xs, Xt=Xt) + +# SinkhornTransport +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1) +ot_sinkhorn.fit(Xs=Xs, Xt=Xt) + +# prediction between images (using out of sample prediction as in [6]) +transp_Xs_emd = ot_emd.transform(Xs=X1) +transp_Xt_emd = ot_emd.inverse_transform(Xt=X2) + +transp_Xs_sinkhorn = ot_emd.transform(Xs=X1) +transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2) + +I1t = minmax(mat2im(transp_Xs_emd, I1.shape)) +I2t = minmax(mat2im(transp_Xt_emd, I2.shape)) + +I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape)) +I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) + + +############################################################################## +# Plot new images ############################################################################## pl.figure(3, figsize=(8, 4)) diff --git a/examples/plot_otda_mapping.py b/examples/plot_otda_mapping.py index 09d2cb4..e0da2d8 100644 --- a/examples/plot_otda_mapping.py +++ b/examples/plot_otda_mapping.py @@ -6,7 +6,7 @@ OT mapping estimation for domain adaptation [8] This example presents how to use MappingTransport to estimate at the same time both the coupling transport and approximate the transport map with either -a linear or a kernelized mapping as introduced in [8] +a linear or a kernelized mapping as introduced in [8]. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", @@ -43,6 +43,17 @@ Xt, yt = ot.datasets.get_data_classif( Xt[yt == 2] *= 3 Xt = Xt + 4 +############################################################################## +# plot data +############################################################################## + +pl.figure(1, (10, 5)) +pl.clf() +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + ############################################################################## # Instantiate the different transport algorithms and fit them @@ -76,19 +87,7 @@ transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new) ############################################################################## -# plot data -############################################################################## - -pl.figure(1, (10, 5)) -pl.clf() -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') -pl.legend(loc=0) -pl.title('Source and target distributions') - - -############################################################################## -# plot transported samples +# Plot transported samples ############################################################################## pl.figure(2) diff --git a/examples/plot_otda_mapping_colors_images.py b/examples/plot_otda_mapping_colors_images.py index 936206c..a8b2ca8 100644 --- a/examples/plot_otda_mapping_colors_images.py +++ b/examples/plot_otda_mapping_colors_images.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- """ -==================================================================================== -OT for domain adaptation with image color adaptation [6] with mapping estimation [8] -==================================================================================== +=============================================== +OT for color adaptation with mapping estimation +=============================================== + +OT for domain adaptation with image color adaptation [6] with mapping +estimation [8]. [6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), @@ -93,7 +96,7 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape)) ############################################################################## -# plot original images +# Plot original images ############################################################################## pl.figure(1, figsize=(6.4, 3)) @@ -110,7 +113,7 @@ pl.tight_layout() ############################################################################## -# plot pixel values distribution +# Plot pixel values distribution ############################################################################## pl.figure(2, figsize=(6.4, 5)) @@ -132,7 +135,7 @@ pl.tight_layout() ############################################################################## -# plot transformed images +# Plot transformed images ############################################################################## pl.figure(2, figsize=(10, 5)) -- cgit v1.2.3