From f439f777084690ecbf54bcd8d67dadc883fffa31 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Fri, 2 Dec 2016 12:49:38 +0100 Subject: first attempt to support sphinx-gallery --- examples/plot_OTDA_2D.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 examples/plot_OTDA_2D.py (limited to 'examples/plot_OTDA_2D.py') diff --git a/examples/plot_OTDA_2D.py b/examples/plot_OTDA_2D.py new file mode 100644 index 0000000..db04c5e --- /dev/null +++ b/examples/plot_OTDA_2D.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +demo of Optimal transport for domain adaptation +=============================================== + +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=150 # nb bins + +xs,ys=ot.datasets.get_data_classif('3gauss',n) +xt,yt=ot.datasets.get_data_classif('3gauss2',n) + +a,b = ot.unif(n),ot.unif(n) +# loss matrix +M=ot.dist(xs,xt) +#M/=M.max() + +#%% plot samples + +pl.figure(1) + +pl.subplot(2,2,1) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Source distributions') + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') +pl.legend(loc=0) +pl.title('target distributions') + +pl.figure(2) +pl.imshow(M,interpolation='nearest') +pl.title('Cost matrix M') + + +#%% OT estimation + +# EMD +G0=ot.emd(a,b,M) + +# sinkhorn +lambd=1e-1 +Gs=ot.sinkhorn(a,b,M,lambd) + + +# Group lasso regularization +reg=1e-1 +eta=1e0 +Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta) + + +#%% visu matrices + +pl.figure(3) + +pl.subplot(2,3,1) +pl.imshow(G0,interpolation='nearest') +pl.title('OT matrix ') + +pl.subplot(2,3,2) +pl.imshow(Gs,interpolation='nearest') +pl.title('OT matrix Sinkhorn') + +pl.subplot(2,3,3) +pl.imshow(Gg,interpolation='nearest') +pl.title('OT matrix Group lasso') + +pl.subplot(2,3,4) +ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + +pl.subplot(2,3,5) +ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.subplot(2,3,6) +ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +#%% sample interpolation + +xst0=n*G0.dot(xt) +xsts=n*Gs.dot(xt) +xstg=n*Gg.dot(xt) + +pl.figure(4) +pl.subplot(2,3,1) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples') +pl.legend(loc=0) + +pl.subplot(2,3,2) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Sinkhorn') + +pl.subplot(2,3,3) + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Grouplasso') \ No newline at end of file -- cgit v1.2.3