From f439f777084690ecbf54bcd8d67dadc883fffa31 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Fri, 2 Dec 2016 12:49:38 +0100 Subject: first attempt to support sphinx-gallery --- examples/demo_OTDA_2D.py | 117 --------------------- examples/demo_OTDA_classes.py | 108 -------------------- examples/demo_OTDA_color_images.py | 143 -------------------------- examples/demo_OTDA_mapping.py | 113 --------------------- examples/demo_OTDA_mapping_color_images.py | 157 ---------------------------- examples/demo_OT_1D.py | 54 ---------- examples/demo_OT_2D_samples.py | 78 -------------- examples/demo_barycenter_1D.py | 135 ------------------------ examples/demo_optim_OTreg.py | 69 ------------- examples/plot_OTDA_2D.py | 120 ++++++++++++++++++++++ examples/plot_OTDA_classes.py | 111 ++++++++++++++++++++ examples/plot_OTDA_color_images.py | 145 ++++++++++++++++++++++++++ examples/plot_OTDA_mapping.py | 110 ++++++++++++++++++++ examples/plot_OTDA_mapping_color_images.py | 158 +++++++++++++++++++++++++++++ examples/plot_OT_1D.py | 56 ++++++++++ examples/plot_OT_2D_samples.py | 76 ++++++++++++++ examples/plot_barycenter_1D.py | 135 ++++++++++++++++++++++++ examples/plot_optim_OTreg.py | 69 +++++++++++++ 18 files changed, 980 insertions(+), 974 deletions(-) delete mode 100644 examples/demo_OTDA_2D.py delete mode 100644 examples/demo_OTDA_classes.py delete mode 100644 examples/demo_OTDA_color_images.py delete mode 100644 examples/demo_OTDA_mapping.py delete mode 100644 examples/demo_OTDA_mapping_color_images.py delete mode 100644 examples/demo_OT_1D.py delete mode 100644 examples/demo_OT_2D_samples.py delete mode 100644 examples/demo_barycenter_1D.py delete mode 100644 examples/demo_optim_OTreg.py create mode 100644 examples/plot_OTDA_2D.py create mode 100644 examples/plot_OTDA_classes.py create mode 100644 examples/plot_OTDA_color_images.py create mode 100644 examples/plot_OTDA_mapping.py create mode 100644 examples/plot_OTDA_mapping_color_images.py create mode 100644 examples/plot_OT_1D.py create mode 100644 examples/plot_OT_2D_samples.py create mode 100644 examples/plot_barycenter_1D.py create mode 100644 examples/plot_optim_OTreg.py (limited to 'examples') diff --git a/examples/demo_OTDA_2D.py b/examples/demo_OTDA_2D.py deleted file mode 100644 index fbaf56d..0000000 --- a/examples/demo_OTDA_2D.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -""" -demo of Optimal transport for domain adaptation -""" - -import numpy as np -import matplotlib.pylab as pl -import ot - - - -#%% parameters - -n=150 # nb bins - -xs,ys=ot.datasets.get_data_classif('3gauss',n) -xt,yt=ot.datasets.get_data_classif('3gauss2',n) - -a,b = ot.unif(n),ot.unif(n) -# loss matrix -M=ot.dist(xs,xt) -#M/=M.max() - -#%% plot samples - -pl.figure(1) - -pl.subplot(2,2,1) -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.legend(loc=0) -pl.title('Source distributions') - -pl.subplot(2,2,2) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') -pl.legend(loc=0) -pl.title('target distributions') - -pl.figure(2) -pl.imshow(M,interpolation='nearest') -pl.title('Cost matrix M') - - -#%% OT estimation - -# EMD -G0=ot.emd(a,b,M) - -# sinkhorn -lambd=1e-1 -Gs=ot.sinkhorn(a,b,M,lambd) - - -# Group lasso regularization -reg=1e-1 -eta=1e0 -Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta) - - -#%% visu matrices - -pl.figure(3) - -pl.subplot(2,3,1) -pl.imshow(G0,interpolation='nearest') -pl.title('OT matrix ') - -pl.subplot(2,3,2) -pl.imshow(Gs,interpolation='nearest') -pl.title('OT matrix Sinkhorn') - -pl.subplot(2,3,3) -pl.imshow(Gg,interpolation='nearest') -pl.title('OT matrix Group lasso') - -pl.subplot(2,3,4) -ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') - - -pl.subplot(2,3,5) -ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1]) -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') - -pl.subplot(2,3,6) -ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1]) -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') - -#%% sample interpolation - -xst0=n*G0.dot(xt) -xsts=n*Gs.dot(xt) -xstg=n*Gg.dot(xt) - -pl.figure(4) -pl.subplot(2,3,1) - - -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples') -pl.legend(loc=0) - -pl.subplot(2,3,2) - - -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) -pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples Sinkhorn') - -pl.subplot(2,3,3) - -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) -pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples Grouplasso') \ No newline at end of file diff --git a/examples/demo_OTDA_classes.py b/examples/demo_OTDA_classes.py deleted file mode 100644 index b2da7de..0000000 --- a/examples/demo_OTDA_classes.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- -""" -demo of Optimal transport for domain adaptation -""" - -import matplotlib.pylab as pl -import ot - - - -#%% parameters - -n=150 # nb samples in source and target datasets - -xs,ys=ot.datasets.get_data_classif('3gauss',n) -xt,yt=ot.datasets.get_data_classif('3gauss2',n) - - - - -#%% plot samples - -pl.figure(1) - -pl.subplot(2,2,1) -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.legend(loc=0) -pl.title('Source distributions') - -pl.subplot(2,2,2) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') -pl.legend(loc=0) -pl.title('target distributions') - - -#%% OT estimation - -# LP problem -da_emd=ot.da.OTDA() # init class -da_emd.fit(xs,xt) # fit distributions -xst0=da_emd.interp() # interpolation of source samples - - -# sinkhorn regularization -lambd=1e-1 -da_entrop=ot.da.OTDA_sinkhorn() -da_entrop.fit(xs,xt,reg=lambd) -xsts=da_entrop.interp() - -# non-convex Group lasso regularization -reg=1e-1 -eta=1e0 -da_lpl1=ot.da.OTDA_lpl1() -da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) -xstg=da_lpl1.interp() - - -# True Group lasso regularization -reg=1e-1 -eta=2e0 -da_l1l2=ot.da.OTDA_l1l2() -da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) -xstgl=da_l1l2.interp() - - -#%% plot interpolated source samples -pl.figure(4,(15,8)) - -param_img={'interpolation':'nearest','cmap':'jet'} - -pl.subplot(2,4,1) -pl.imshow(da_emd.G,**param_img) -pl.title('OT matrix') - - -pl.subplot(2,4,2) -pl.imshow(da_entrop.G,**param_img) -pl.title('OT matrix sinkhorn') - -pl.subplot(2,4,3) -pl.imshow(da_lpl1.G,**param_img) -pl.title('OT matrix non-convex Group Lasso') - -pl.subplot(2,4,4) -pl.imshow(da_l1l2.G,**param_img) -pl.title('OT matrix Group Lasso') - - -pl.subplot(2,4,5) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples') -pl.legend(loc=0) - -pl.subplot(2,4,6) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples Sinkhorn') - -pl.subplot(2,4,7) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples non-convex Group Lasso') - -pl.subplot(2,4,8) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples Group Lasso') \ No newline at end of file diff --git a/examples/demo_OTDA_color_images.py b/examples/demo_OTDA_color_images.py deleted file mode 100644 index 715707d..0000000 --- a/examples/demo_OTDA_color_images.py +++ /dev/null @@ -1,143 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Demo of Optimal transport for domain adaptation with image color adaptation as in [6] - -[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. -""" - -import numpy as np -import scipy.ndimage as spi -import matplotlib.pylab as pl -import ot - - -#%% Loading images - -I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 -I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 - -#%% Plot images - -pl.figure(1) - -pl.subplot(1,2,1) -pl.imshow(I1) -pl.title('Image 1') - -pl.subplot(1,2,2) -pl.imshow(I2) -pl.title('Image 2') - -pl.show() - -#%% Image conversion and dataset generation - -def im2mat(I): - """Converts and image to matrix (one pixel per line)""" - return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) - -def mat2im(X,shape): - """Converts back a matrix to an image""" - return X.reshape(shape) - -X1=im2mat(I1) -X2=im2mat(I2) - -# training samples -nb=1000 -idx1=np.random.randint(X1.shape[0],size=(nb,)) -idx2=np.random.randint(X2.shape[0],size=(nb,)) - -xs=X1[idx1,:] -xt=X2[idx2,:] - -#%% Plot image distributions - - -pl.figure(2,(10,5)) - -pl.subplot(1,2,1) -pl.scatter(xs[:,0],xs[:,2],c=xs) -pl.axis([0,1,0,1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') - -pl.subplot(1,2,2) -#pl.imshow(I2) -pl.scatter(xt[:,0],xt[:,2],c=xt) -pl.axis([0,1,0,1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') - -pl.show() - - - -#%% domain adaptation between images - -# LP problem -da_emd=ot.da.OTDA() # init class -da_emd.fit(xs,xt) # fit distributions - - -# sinkhorn regularization -lambd=1e-1 -da_entrop=ot.da.OTDA_sinkhorn() -da_entrop.fit(xs,xt,reg=lambd) - - - -#%% prediction between images (using out of sample prediction as in [6]) - -X1t=da_emd.predict(X1) -X2t=da_emd.predict(X2,-1) - - -X1te=da_entrop.predict(X1) -X2te=da_entrop.predict(X2,-1) - - -def minmax(I): - return np.minimum(np.maximum(I,0),1) - -I1t=minmax(mat2im(X1t,I1.shape)) -I2t=minmax(mat2im(X2t,I2.shape)) - -I1te=minmax(mat2im(X1te,I1.shape)) -I2te=minmax(mat2im(X2te,I2.shape)) - -#%% plot all images - -pl.figure(2,(10,8)) - -pl.subplot(2,3,1) - -pl.imshow(I1) -pl.title('Image 1') - -pl.subplot(2,3,2) -pl.imshow(I1t) -pl.title('Image 1 Adapt') - - -pl.subplot(2,3,3) -pl.imshow(I1te) -pl.title('Image 1 Adapt (reg)') - -pl.subplot(2,3,4) - -pl.imshow(I2) -pl.title('Image 2') - -pl.subplot(2,3,5) -pl.imshow(I2t) -pl.title('Image 2 Adapt') - - -pl.subplot(2,3,6) -pl.imshow(I2te) -pl.title('Image 2 Adapt (reg)') - -pl.show() \ No newline at end of file diff --git a/examples/demo_OTDA_mapping.py b/examples/demo_OTDA_mapping.py deleted file mode 100644 index 50c1df2..0000000 --- a/examples/demo_OTDA_mapping.py +++ /dev/null @@ -1,113 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Demo of OT mapping estimation for domain adaptation [8] - -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for - discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. -""" - -import numpy as np -import matplotlib.pylab as pl -import ot - - - -#%% dataset generation - -np.random.seed(0) # makes example reproducible - -n=100 # nb samples in source and target datasets -theta=2*np.pi/20 -nz=0.1 -xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) -xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) - -# one of the target mode changes its variance (no linear mapping) -xt[yt==2]*=3 -xt=xt+4 - - -#%% plot samples - -pl.figure(1,(8,5)) -pl.clf() - -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') - -pl.legend(loc=0) -pl.title('Source and target distributions') - - - -#%% OT linear mapping estimation - -eta=1e-8 # quadratic regularization for regression -mu=1e0 # weight of the OT linear term -bias=True # estimate a bias - -ot_mapping=ot.da.OTDA_mapping_linear() -ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) - -xst=ot_mapping.predict(xs) # use the estimated mapping -xst0=ot_mapping.interp() # use barycentric mapping - - -pl.figure(2,(10,7)) -pl.clf() -pl.subplot(2,2,1) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') -pl.title("barycentric mapping") - -pl.subplot(2,2,2) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) -pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') -pl.title("Learned mapping") - - - -#%% Kernel mapping estimation - -eta=1e-5 # quadratic regularization for regression -mu=1e-1 # weight of the OT linear term -bias=True # estimate a bias -sigma=1 # sigma bandwidth fot gaussian kernel - - -ot_mapping_kernel=ot.da.OTDA_mapping_kernel() -ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) - -xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping -xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping - - -#%% Plotting the mapped samples - -pl.figure(2,(10,7)) -pl.clf() -pl.subplot(2,2,1) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') -pl.title("Bary. mapping (linear)") -pl.legend(loc=0) - -pl.subplot(2,2,2) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) -pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') -pl.title("Estim. mapping (linear)") - -pl.subplot(2,2,3) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) -pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') -pl.title("Bary. mapping (kernel)") - -pl.subplot(2,2,4) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) -pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') -pl.title("Estim. mapping (kernel)") - - - - - diff --git a/examples/demo_OTDA_mapping_color_images.py b/examples/demo_OTDA_mapping_color_images.py deleted file mode 100644 index 2744f6c..0000000 --- a/examples/demo_OTDA_mapping_color_images.py +++ /dev/null @@ -1,157 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Demo of Optimal transport for domain adaptation with image color adaptation as in [6] with mapping estimation from [8] - -[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized - discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for - discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. - - -""" - -import numpy as np -import scipy.ndimage as spi -import matplotlib.pylab as pl -import ot - - -#%% Loading images - -I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 -I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 - -#%% Plot images - -pl.figure(1) - -pl.subplot(1,2,1) -pl.imshow(I1) -pl.title('Image 1') - -pl.subplot(1,2,2) -pl.imshow(I2) -pl.title('Image 2') - -pl.show() - -#%% Image conversion and dataset generation - -def im2mat(I): - """Converts and image to matrix (one pixel per line)""" - return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) - -def mat2im(X,shape): - """Converts back a matrix to an image""" - return X.reshape(shape) - -X1=im2mat(I1) -X2=im2mat(I2) - -# training samples -nb=1000 -idx1=np.random.randint(X1.shape[0],size=(nb,)) -idx2=np.random.randint(X2.shape[0],size=(nb,)) - -xs=X1[idx1,:] -xt=X2[idx2,:] - -#%% Plot image distributions - - -pl.figure(2,(10,5)) - -pl.subplot(1,2,1) -pl.scatter(xs[:,0],xs[:,2],c=xs) -pl.axis([0,1,0,1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') - -pl.subplot(1,2,2) -#pl.imshow(I2) -pl.scatter(xt[:,0],xt[:,2],c=xt) -pl.axis([0,1,0,1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') - -pl.show() - - - -#%% domain adaptation between images -def minmax(I): - return np.minimum(np.maximum(I,0),1) -# LP problem -da_emd=ot.da.OTDA() # init class -da_emd.fit(xs,xt) # fit distributions - -X1t=da_emd.predict(X1) # out of sample -I1t=minmax(mat2im(X1t,I1.shape)) - -# sinkhorn regularization -lambd=1e-1 -da_entrop=ot.da.OTDA_sinkhorn() -da_entrop.fit(xs,xt,reg=lambd) - -X1te=da_entrop.predict(X1) -I1te=minmax(mat2im(X1te,I1.shape)) - -# linear mapping estimation -eta=1e-8 # quadratic regularization for regression -mu=1e0 # weight of the OT linear term -bias=True # estimate a bias - -ot_mapping=ot.da.OTDA_mapping_linear() -ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) - -X1tl=ot_mapping.predict(X1) # use the estimated mapping -I1tl=minmax(mat2im(X1tl,I1.shape)) - -# nonlinear mapping estimation -eta=1e-2 # quadratic regularization for regression -mu=1e0 # weight of the OT linear term -bias=False # estimate a bias -sigma=1 # sigma bandwidth fot gaussian kernel - - -ot_mapping_kernel=ot.da.OTDA_mapping_kernel() -ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) - -X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping -I1tn=minmax(mat2im(X1tn,I1.shape)) -#%% plot images - - -pl.figure(2,(10,8)) - -pl.subplot(2,3,1) - -pl.imshow(I1) -pl.title('Im. 1') - -pl.subplot(2,3,2) - -pl.imshow(I2) -pl.title('Im. 2') - - -pl.subplot(2,3,3) -pl.imshow(I1t) -pl.title('Im. 1 Interp LP') - -pl.subplot(2,3,4) -pl.imshow(I1te) -pl.title('Im. 1 Interp Entrop') - - -pl.subplot(2,3,5) -pl.imshow(I1tl) -pl.title('Im. 1 Linear mapping') - -pl.subplot(2,3,6) -pl.imshow(I1tn) -pl.title('Im. 1 nonlinear mapping') - -pl.show() \ No newline at end of file diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py deleted file mode 100644 index df65a60..0000000 --- a/examples/demo_OT_1D.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Demo for 1D optimal transport - -@author: rflamary -""" - -import numpy as np -import matplotlib.pylab as pl -import ot -from ot.datasets import get_1D_gauss as gauss - - -#%% parameters - -n=100 # nb bins - -# bin positions -x=np.arange(n,dtype=np.float64) - -# Gaussian distributions -a=gauss(n,m=20,s=5) # m= mean, s= std -b=gauss(n,m=60,s=10) - -# loss matrix -M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) -M/=M.max() - -#%% plot the distributions - -pl.figure(1) -pl.plot(x,a,'b',label='Source distribution') -pl.plot(x,b,'r',label='Target distribution') -pl.legend() - -#%% plot distributions and loss matrix - -pl.figure(2) -ot.plot.plot1D_mat(a,b,M,'Cost matrix M') - -#%% EMD - -G0=ot.emd(a,b,M) - -pl.figure(3) -ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') - -#%% Sinkhorn - -lambd=1e-3 -Gs=ot.sinkhorn(a,b,M,lambd) - -pl.figure(4) -ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py deleted file mode 100644 index e9ec78a..0000000 --- a/examples/demo_OT_2D_samples.py +++ /dev/null @@ -1,78 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Demo for 2D Optimal transport between empirical distributions - -@author: rflamary -""" - -import numpy as np -import matplotlib.pylab as pl -import ot - -#%% parameters and data generation - -n=20 # nb samples - -mu_s=np.array([0,0]) -cov_s=np.array([[1,0],[0,1]]) - -mu_t=np.array([4,4]) -cov_t=np.array([[1,-.8],[-.8,1]]) - -xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) -xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) - -a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples - -# loss matrix -M=ot.dist(xs,xt) -M/=M.max() - -#%% plot samples - -pl.figure(1) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') -pl.legend(loc=0) -pl.title('Source and traget distributions') - -pl.figure(2) -pl.imshow(M,interpolation='nearest') -pl.title('Cost matrix M') - - -#%% EMD - -G0=ot.emd(a,b,M) - -pl.figure(3) -pl.imshow(G0,interpolation='nearest') -pl.title('OT matrix G0') - -pl.figure(4) -ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') -pl.legend(loc=0) -pl.title('OT matrix with samples') - - -#%% sinkhorn - -# reg term -lambd=5e-3 - -Gs=ot.sinkhorn(a,b,M,lambd) - -pl.figure(5) -pl.imshow(Gs,interpolation='nearest') -pl.title('OT matrix sinkhorn') - -pl.figure(6) -ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) -pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') -pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') -pl.legend(loc=0) -pl.title('OT matrix Sinkhorn with samples') - - diff --git a/examples/demo_barycenter_1D.py b/examples/demo_barycenter_1D.py deleted file mode 100644 index 5466332..0000000 --- a/examples/demo_barycenter_1D.py +++ /dev/null @@ -1,135 +0,0 @@ -# -*- coding: utf-8 -*- -""" -1D Wasserstein barycenter demo - -@author: rflamary -""" - -import numpy as np -import matplotlib.pylab as pl -import ot -from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used -from matplotlib.collections import PolyCollection - - -#%% parameters - -n=100 # nb bins - -# bin positions -x=np.arange(n,dtype=np.float64) - -# Gaussian distributions -a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std -a2=ot.datasets.get_1D_gauss(n,m=60,s=8) - -# creating matrix A containing all distributions -A=np.vstack((a1,a2)).T -nbd=A.shape[1] - -# loss matrix + normalization -M=ot.utils.dist0(n) -M/=M.max() - -#%% plot the distributions - -pl.figure(1) -for i in range(nbd): - pl.plot(x,A[:,i]) -pl.title('Distributions') - -#%% barycenter computation - -alpha=0.2 # 0<=alpha<=1 -weights=np.array([1-alpha,alpha]) - -# l2bary -bary_l2=A.dot(weights) - -# wasserstein -reg=1e-3 -bary_wass=ot.bregman.barycenter(A,M,reg,weights) - -pl.figure(2) -pl.clf() -pl.subplot(2,1,1) -for i in range(nbd): - pl.plot(x,A[:,i]) -pl.title('Distributions') - -pl.subplot(2,1,2) -pl.plot(x,bary_l2,'r',label='l2') -pl.plot(x,bary_wass,'g',label='Wasserstein') -pl.legend() -pl.title('Barycenters') - - -#%% barycenter interpolation - -nbalpha=11 -alphalist=np.linspace(0,1,nbalpha) - - -B_l2=np.zeros((n,nbalpha)) - -B_wass=np.copy(B_l2) - -for i in range(0,nbalpha): - alpha=alphalist[i] - weights=np.array([1-alpha,alpha]) - B_l2[:,i]=A.dot(weights) - B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) - -#%% plot interpolation - -pl.figure(3,(10,5)) - -#pl.subplot(1,2,1) -cmap=pl.cm.get_cmap('viridis') -verts = [] -zs = alphalist -for i,z in enumerate(zs): - ys = B_l2[:,i] - verts.append(list(zip(x, ys))) - -ax = pl.gcf().gca(projection='3d') - -poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) -poly.set_alpha(0.7) -ax.add_collection3d(poly, zs=zs, zdir='y') - -ax.set_xlabel('x') -ax.set_xlim3d(0, n) -ax.set_ylabel('$\\alpha$') -ax.set_ylim3d(0,1) -ax.set_zlabel('') -ax.set_zlim3d(0, B_l2.max()*1.01) -pl.title('Barycenter interpolation with l2') - -pl.show() - -pl.figure(4,(10,5)) - -#pl.subplot(1,2,1) -cmap=pl.cm.get_cmap('viridis') -verts = [] -zs = alphalist -for i,z in enumerate(zs): - ys = B_wass[:,i] - verts.append(list(zip(x, ys))) - -ax = pl.gcf().gca(projection='3d') - -poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) -poly.set_alpha(0.7) -ax.add_collection3d(poly, zs=zs, zdir='y') - -ax.set_xlabel('x') -ax.set_xlim3d(0, n) -ax.set_ylabel('$\\alpha$') -ax.set_ylim3d(0,1) -ax.set_zlabel('') -ax.set_zlim3d(0, B_l2.max()*1.01) -pl.title('Barycenter interpolation with Wasserstein') - -pl.show() \ No newline at end of file diff --git a/examples/demo_optim_OTreg.py b/examples/demo_optim_OTreg.py deleted file mode 100644 index 0de6b08..0000000 --- a/examples/demo_optim_OTreg.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Regularized OT with generic solver -""" - -import numpy as np -import matplotlib.pylab as pl -import ot - - - -#%% parameters - -n=100 # nb bins - -# bin positions -x=np.arange(n,dtype=np.float64) - -# Gaussian distributions -a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std -b=ot.datasets.get_1D_gauss(n,m=60,s=10) - -# loss matrix -M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) -M/=M.max() - -#%% EMD - -G0=ot.emd(a,b,M) - -pl.figure(3) -ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') - -#%% Example with Frobenius norm regularization - -def f(G): return 0.5*np.sum(G**2) -def df(G): return G - -reg=1e-1 - -Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True) - -pl.figure(3) -ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg') - -#%% Example with entropic regularization - -def f(G): return np.sum(G*np.log(G)) -def df(G): return np.log(G)+1 - -reg=1e-3 - -Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True) - -pl.figure(4) -ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg') - -#%% Example with Frobenius norm + entropic regularization with gcg - -def f(G): return 0.5*np.sum(G**2) -def df(G): return G - -reg1=1e-1 -reg2=1e-1 - -Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True) - -pl.figure(5) -ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg') \ No newline at end of file diff --git a/examples/plot_OTDA_2D.py b/examples/plot_OTDA_2D.py new file mode 100644 index 0000000..db04c5e --- /dev/null +++ b/examples/plot_OTDA_2D.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +demo of Optimal transport for domain adaptation +=============================================== + +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=150 # nb bins + +xs,ys=ot.datasets.get_data_classif('3gauss',n) +xt,yt=ot.datasets.get_data_classif('3gauss2',n) + +a,b = ot.unif(n),ot.unif(n) +# loss matrix +M=ot.dist(xs,xt) +#M/=M.max() + +#%% plot samples + +pl.figure(1) + +pl.subplot(2,2,1) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Source distributions') + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') +pl.legend(loc=0) +pl.title('target distributions') + +pl.figure(2) +pl.imshow(M,interpolation='nearest') +pl.title('Cost matrix M') + + +#%% OT estimation + +# EMD +G0=ot.emd(a,b,M) + +# sinkhorn +lambd=1e-1 +Gs=ot.sinkhorn(a,b,M,lambd) + + +# Group lasso regularization +reg=1e-1 +eta=1e0 +Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta) + + +#%% visu matrices + +pl.figure(3) + +pl.subplot(2,3,1) +pl.imshow(G0,interpolation='nearest') +pl.title('OT matrix ') + +pl.subplot(2,3,2) +pl.imshow(Gs,interpolation='nearest') +pl.title('OT matrix Sinkhorn') + +pl.subplot(2,3,3) +pl.imshow(Gg,interpolation='nearest') +pl.title('OT matrix Group lasso') + +pl.subplot(2,3,4) +ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + +pl.subplot(2,3,5) +ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.subplot(2,3,6) +ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +#%% sample interpolation + +xst0=n*G0.dot(xt) +xsts=n*Gs.dot(xt) +xstg=n*Gg.dot(xt) + +pl.figure(4) +pl.subplot(2,3,1) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples') +pl.legend(loc=0) + +pl.subplot(2,3,2) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Sinkhorn') + +pl.subplot(2,3,3) + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Grouplasso') \ No newline at end of file diff --git a/examples/plot_OTDA_classes.py b/examples/plot_OTDA_classes.py new file mode 100644 index 0000000..999be53 --- /dev/null +++ b/examples/plot_OTDA_classes.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +demo of Optimal transport for domain adaptation +=============================================== + +""" + +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=150 # nb samples in source and target datasets + +xs,ys=ot.datasets.get_data_classif('3gauss',n) +xt,yt=ot.datasets.get_data_classif('3gauss2',n) + + + + +#%% plot samples + +pl.figure(1) + +pl.subplot(2,2,1) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Source distributions') + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') +pl.legend(loc=0) +pl.title('target distributions') + + +#%% OT estimation + +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions +xst0=da_emd.interp() # interpolation of source samples + + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) +xsts=da_entrop.interp() + +# non-convex Group lasso regularization +reg=1e-1 +eta=1e0 +da_lpl1=ot.da.OTDA_lpl1() +da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) +xstg=da_lpl1.interp() + + +# True Group lasso regularization +reg=1e-1 +eta=2e0 +da_l1l2=ot.da.OTDA_l1l2() +da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) +xstgl=da_l1l2.interp() + + +#%% plot interpolated source samples +pl.figure(4,(15,8)) + +param_img={'interpolation':'nearest','cmap':'jet'} + +pl.subplot(2,4,1) +pl.imshow(da_emd.G,**param_img) +pl.title('OT matrix') + + +pl.subplot(2,4,2) +pl.imshow(da_entrop.G,**param_img) +pl.title('OT matrix sinkhorn') + +pl.subplot(2,4,3) +pl.imshow(da_lpl1.G,**param_img) +pl.title('OT matrix non-convex Group Lasso') + +pl.subplot(2,4,4) +pl.imshow(da_l1l2.G,**param_img) +pl.title('OT matrix Group Lasso') + + +pl.subplot(2,4,5) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples') +pl.legend(loc=0) + +pl.subplot(2,4,6) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Sinkhorn') + +pl.subplot(2,4,7) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples non-convex Group Lasso') + +pl.subplot(2,4,8) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Group Lasso') \ No newline at end of file diff --git a/examples/plot_OTDA_color_images.py b/examples/plot_OTDA_color_images.py new file mode 100644 index 0000000..7ddaa2b --- /dev/null +++ b/examples/plot_OTDA_color_images.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +""" +===================================================================================== +Demo of Optimal transport for domain adaptation with image color adaptation as in [6] +===================================================================================== + +[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +""" + +import numpy as np +import scipy.ndimage as spi +import matplotlib.pylab as pl +import ot + + +#%% Loading images + +I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 +I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + +#%% Plot images + +pl.figure(1) + +pl.subplot(1,2,1) +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(1,2,2) +pl.imshow(I2) +pl.title('Image 2') + +pl.show() + +#%% Image conversion and dataset generation + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + +def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + +X1=im2mat(I1) +X2=im2mat(I2) + +# training samples +nb=1000 +idx1=np.random.randint(X1.shape[0],size=(nb,)) +idx2=np.random.randint(X2.shape[0],size=(nb,)) + +xs=X1[idx1,:] +xt=X2[idx2,:] + +#%% Plot image distributions + + +pl.figure(2,(10,5)) + +pl.subplot(1,2,1) +pl.scatter(xs[:,0],xs[:,2],c=xs) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 1') + +pl.subplot(1,2,2) +#pl.imshow(I2) +pl.scatter(xt[:,0],xt[:,2],c=xt) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 2') + +pl.show() + + + +#%% domain adaptation between images + +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions + + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) + + + +#%% prediction between images (using out of sample prediction as in [6]) + +X1t=da_emd.predict(X1) +X2t=da_emd.predict(X2,-1) + + +X1te=da_entrop.predict(X1) +X2te=da_entrop.predict(X2,-1) + + +def minmax(I): + return np.minimum(np.maximum(I,0),1) + +I1t=minmax(mat2im(X1t,I1.shape)) +I2t=minmax(mat2im(X2t,I2.shape)) + +I1te=minmax(mat2im(X1te,I1.shape)) +I2te=minmax(mat2im(X2te,I2.shape)) + +#%% plot all images + +pl.figure(2,(10,8)) + +pl.subplot(2,3,1) + +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(2,3,2) +pl.imshow(I1t) +pl.title('Image 1 Adapt') + + +pl.subplot(2,3,3) +pl.imshow(I1te) +pl.title('Image 1 Adapt (reg)') + +pl.subplot(2,3,4) + +pl.imshow(I2) +pl.title('Image 2') + +pl.subplot(2,3,5) +pl.imshow(I2t) +pl.title('Image 2 Adapt') + + +pl.subplot(2,3,6) +pl.imshow(I2te) +pl.title('Image 2 Adapt (reg)') + +pl.show() \ No newline at end of file diff --git a/examples/plot_OTDA_mapping.py b/examples/plot_OTDA_mapping.py new file mode 100644 index 0000000..779fa76 --- /dev/null +++ b/examples/plot_OTDA_mapping.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +""" +======================================================= +Demo of OT mapping estimation for domain adaptation [8] +======================================================= + +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% dataset generation + +np.random.seed(0) # makes example reproducible + +n=100 # nb samples in source and target datasets +theta=2*np.pi/20 +nz=0.1 +xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) +xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) + +# one of the target mode changes its variance (no linear mapping) +xt[yt==2]*=3 +xt=xt+4 + + +#%% plot samples + +pl.figure(1,(8,5)) +pl.clf() + +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.legend(loc=0) +pl.title('Source and target distributions') + + + +#%% OT linear mapping estimation + +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +xst=ot_mapping.predict(xs) # use the estimated mapping +xst0=ot_mapping.interp() # use barycentric mapping + + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("barycentric mapping") + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Learned mapping") + + + +#%% Kernel mapping estimation + +eta=1e-5 # quadratic regularization for regression +mu=1e-1 # weight of the OT linear term +bias=True # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping +xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping + + +#%% Plotting the mapped samples + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') +pl.title("Bary. mapping (linear)") +pl.legend(loc=0) + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (linear)") + +pl.subplot(2,2,3) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("Bary. mapping (kernel)") + +pl.subplot(2,2,4) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (kernel)") diff --git a/examples/plot_OTDA_mapping_color_images.py b/examples/plot_OTDA_mapping_color_images.py new file mode 100644 index 0000000..0cd6c9c --- /dev/null +++ b/examples/plot_OTDA_mapping_color_images.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +""" +====================================================================================================================== +Demo of Optimal transport for domain adaptation with image color adaptation as in [6] with mapping estimation from [8] +====================================================================================================================== + +[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized + discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. + +""" + +import numpy as np +import scipy.ndimage as spi +import matplotlib.pylab as pl +import ot + + +#%% Loading images + +I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 +I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + +#%% Plot images + +pl.figure(1) + +pl.subplot(1,2,1) +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(1,2,2) +pl.imshow(I2) +pl.title('Image 2') + +pl.show() + +#%% Image conversion and dataset generation + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + +def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + +X1=im2mat(I1) +X2=im2mat(I2) + +# training samples +nb=1000 +idx1=np.random.randint(X1.shape[0],size=(nb,)) +idx2=np.random.randint(X2.shape[0],size=(nb,)) + +xs=X1[idx1,:] +xt=X2[idx2,:] + +#%% Plot image distributions + + +pl.figure(2,(10,5)) + +pl.subplot(1,2,1) +pl.scatter(xs[:,0],xs[:,2],c=xs) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 1') + +pl.subplot(1,2,2) +#pl.imshow(I2) +pl.scatter(xt[:,0],xt[:,2],c=xt) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 2') + +pl.show() + + + +#%% domain adaptation between images +def minmax(I): + return np.minimum(np.maximum(I,0),1) +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions + +X1t=da_emd.predict(X1) # out of sample +I1t=minmax(mat2im(X1t,I1.shape)) + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) + +X1te=da_entrop.predict(X1) +I1te=minmax(mat2im(X1te,I1.shape)) + +# linear mapping estimation +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +X1tl=ot_mapping.predict(X1) # use the estimated mapping +I1tl=minmax(mat2im(X1tl,I1.shape)) + +# nonlinear mapping estimation +eta=1e-2 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=False # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping +I1tn=minmax(mat2im(X1tn,I1.shape)) +#%% plot images + + +pl.figure(2,(10,8)) + +pl.subplot(2,3,1) + +pl.imshow(I1) +pl.title('Im. 1') + +pl.subplot(2,3,2) + +pl.imshow(I2) +pl.title('Im. 2') + + +pl.subplot(2,3,3) +pl.imshow(I1t) +pl.title('Im. 1 Interp LP') + +pl.subplot(2,3,4) +pl.imshow(I1te) +pl.title('Im. 1 Interp Entrop') + + +pl.subplot(2,3,5) +pl.imshow(I1tl) +pl.title('Im. 1 Linear mapping') + +pl.subplot(2,3,6) +pl.imshow(I1tn) +pl.title('Im. 1 nonlinear mapping') + +pl.show() \ No newline at end of file diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py new file mode 100644 index 0000000..a8bbbd6 --- /dev/null +++ b/examples/plot_OT_1D.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +""" +============================= +Demo for 1D optimal transport +============================= + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import get_1D_gauss as gauss + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a=gauss(n,m=20,s=5) # m= mean, s= std +b=gauss(n,m=60,s=10) + +# loss matrix +M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) +M/=M.max() + +#%% plot the distributions + +pl.figure(1) +pl.plot(x,a,'b',label='Source distribution') +pl.plot(x,b,'r',label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2) +ot.plot.plot1D_mat(a,b,M,'Cost matrix M') + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + +#%% Sinkhorn + +lambd=1e-3 +Gs=ot.sinkhorn(a,b,M,lambd) + +pl.figure(4) +ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py new file mode 100644 index 0000000..e55eff8 --- /dev/null +++ b/examples/plot_OT_2D_samples.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +Demo for 2D Optimal transport between empirical distributions + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + +#%% parameters and data generation + +n=20 # nb samples + +mu_s=np.array([0,0]) +cov_s=np.array([[1,0],[0,1]]) + +mu_t=np.array([4,4]) +cov_t=np.array([[1,-.8],[-.8,1]]) + +xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) +xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t) + +a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples + +# loss matrix +M=ot.dist(xs,xt) +M/=M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.legend(loc=0) +pl.title('Source and traget distributions') + +pl.figure(2) +pl.imshow(M,interpolation='nearest') +pl.title('Cost matrix M') + + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +pl.imshow(G0,interpolation='nearest') +pl.title('OT matrix G0') + +pl.figure(4) +ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix with samples') + + +#%% sinkhorn + +# reg term +lambd=5e-3 + +Gs=ot.sinkhorn(a,b,M,lambd) + +pl.figure(5) +pl.imshow(Gs,interpolation='nearest') +pl.title('OT matrix sinkhorn') + +pl.figure(6) +ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1]) +pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples') +pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix Sinkhorn with samples') diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py new file mode 100644 index 0000000..5466332 --- /dev/null +++ b/examples/plot_barycenter_1D.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +""" +1D Wasserstein barycenter demo + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used +from matplotlib.collections import PolyCollection + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std +a2=ot.datasets.get_1D_gauss(n,m=60,s=8) + +# creating matrix A containing all distributions +A=np.vstack((a1,a2)).T +nbd=A.shape[1] + +# loss matrix + normalization +M=ot.utils.dist0(n) +M/=M.max() + +#%% plot the distributions + +pl.figure(1) +for i in range(nbd): + pl.plot(x,A[:,i]) +pl.title('Distributions') + +#%% barycenter computation + +alpha=0.2 # 0<=alpha<=1 +weights=np.array([1-alpha,alpha]) + +# l2bary +bary_l2=A.dot(weights) + +# wasserstein +reg=1e-3 +bary_wass=ot.bregman.barycenter(A,M,reg,weights) + +pl.figure(2) +pl.clf() +pl.subplot(2,1,1) +for i in range(nbd): + pl.plot(x,A[:,i]) +pl.title('Distributions') + +pl.subplot(2,1,2) +pl.plot(x,bary_l2,'r',label='l2') +pl.plot(x,bary_wass,'g',label='Wasserstein') +pl.legend() +pl.title('Barycenters') + + +#%% barycenter interpolation + +nbalpha=11 +alphalist=np.linspace(0,1,nbalpha) + + +B_l2=np.zeros((n,nbalpha)) + +B_wass=np.copy(B_l2) + +for i in range(0,nbalpha): + alpha=alphalist[i] + weights=np.array([1-alpha,alpha]) + B_l2[:,i]=A.dot(weights) + B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights) + +#%% plot interpolation + +pl.figure(3,(10,5)) + +#pl.subplot(1,2,1) +cmap=pl.cm.get_cmap('viridis') +verts = [] +zs = alphalist +for i,z in enumerate(zs): + ys = B_l2[:,i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') + +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0,1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max()*1.01) +pl.title('Barycenter interpolation with l2') + +pl.show() + +pl.figure(4,(10,5)) + +#pl.subplot(1,2,1) +cmap=pl.cm.get_cmap('viridis') +verts = [] +zs = alphalist +for i,z in enumerate(zs): + ys = B_wass[:,i] + verts.append(list(zip(x, ys))) + +ax = pl.gcf().gca(projection='3d') + +poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist]) +poly.set_alpha(0.7) +ax.add_collection3d(poly, zs=zs, zdir='y') + +ax.set_xlabel('x') +ax.set_xlim3d(0, n) +ax.set_ylabel('$\\alpha$') +ax.set_ylim3d(0,1) +ax.set_zlabel('') +ax.set_zlim3d(0, B_l2.max()*1.01) +pl.title('Barycenter interpolation with Wasserstein') + +pl.show() \ No newline at end of file diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py new file mode 100644 index 0000000..0de6b08 --- /dev/null +++ b/examples/plot_optim_OTreg.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +""" +Regularized OT with generic solver +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=100 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +# Gaussian distributions +a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std +b=ot.datasets.get_1D_gauss(n,m=60,s=10) + +# loss matrix +M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) +M/=M.max() + +#%% EMD + +G0=ot.emd(a,b,M) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') + +#%% Example with Frobenius norm regularization + +def f(G): return 0.5*np.sum(G**2) +def df(G): return G + +reg=1e-1 + +Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True) + +pl.figure(3) +ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg') + +#%% Example with entropic regularization + +def f(G): return np.sum(G*np.log(G)) +def df(G): return np.log(G)+1 + +reg=1e-3 + +Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True) + +pl.figure(4) +ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg') + +#%% Example with Frobenius norm + entropic regularization with gcg + +def f(G): return 0.5*np.sum(G**2) +def df(G): return G + +reg1=1e-1 +reg2=1e-1 + +Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True) + +pl.figure(5) +ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg') \ No newline at end of file -- cgit v1.2.3