diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2016-12-02 12:49:38 +0100 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2016-12-02 12:57:25 +0100 |
commit | f439f777084690ecbf54bcd8d67dadc883fffa31 (patch) | |
tree | 56c8a160fb6edcdaca8ca6ce6de1949b9bc33b77 /examples/plot_OTDA_2D.py | |
parent | 8dbfd3edae649f5f3e87be4a3ce446c59729b2f7 (diff) |
first attempt to support sphinx-gallery
Diffstat (limited to 'examples/plot_OTDA_2D.py')
-rw-r--r-- | examples/plot_OTDA_2D.py | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/examples/plot_OTDA_2D.py b/examples/plot_OTDA_2D.py new file mode 100644 index 0000000..db04c5e --- /dev/null +++ b/examples/plot_OTDA_2D.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +demo of Optimal transport for domain adaptation +=============================================== + +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% parameters + +n=150 # nb bins + +xs,ys=ot.datasets.get_data_classif('3gauss',n) +xt,yt=ot.datasets.get_data_classif('3gauss2',n) + +a,b = ot.unif(n),ot.unif(n) +# loss matrix +M=ot.dist(xs,xt) +#M/=M.max() + +#%% plot samples + +pl.figure(1) + +pl.subplot(2,2,1) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Source distributions') + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') +pl.legend(loc=0) +pl.title('target distributions') + +pl.figure(2) +pl.imshow(M,interpolation='nearest') +pl.title('Cost matrix M') + + +#%% OT estimation + +# EMD +G0=ot.emd(a,b,M) + +# sinkhorn +lambd=1e-1 +Gs=ot.sinkhorn(a,b,M,lambd) + + +# Group lasso regularization +reg=1e-1 +eta=1e0 +Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta) + + +#%% visu matrices + +pl.figure(3) + +pl.subplot(2,3,1) +pl.imshow(G0,interpolation='nearest') +pl.title('OT matrix ') + +pl.subplot(2,3,2) +pl.imshow(Gs,interpolation='nearest') +pl.title('OT matrix Sinkhorn') + +pl.subplot(2,3,3) +pl.imshow(Gg,interpolation='nearest') +pl.title('OT matrix Group lasso') + +pl.subplot(2,3,4) +ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + + +pl.subplot(2,3,5) +ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.subplot(2,3,6) +ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1]) +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +#%% sample interpolation + +xst0=n*G0.dot(xt) +xsts=n*Gs.dot(xt) +xstg=n*Gg.dot(xt) + +pl.figure(4) +pl.subplot(2,3,1) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples') +pl.legend(loc=0) + +pl.subplot(2,3,2) + + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Sinkhorn') + +pl.subplot(2,3,3) + +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5) +pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples Grouplasso')
\ No newline at end of file |