diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-09-15 14:54:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-09-15 14:54:21 +0200 |
commit | 81b2796226f3abde29fc024752728444da77509a (patch) | |
tree | c52cec3c38552f9f8c15361758aa9a80c30c3ef3 /docs/source/auto_examples/plot_OTDA_classes.py | |
parent | e70d5420204db78691af2d0fbe04cc3d4416a8f4 (diff) | |
parent | 7fea2cd3e8ad29bf3fa442d7642bae124ee2bab0 (diff) |
Merge pull request #27 from rflamary/autonb
auto notebooks + release update (fixes #16)
Diffstat (limited to 'docs/source/auto_examples/plot_OTDA_classes.py')
-rw-r--r-- | docs/source/auto_examples/plot_OTDA_classes.py | 112 |
1 files changed, 0 insertions, 112 deletions
diff --git a/docs/source/auto_examples/plot_OTDA_classes.py b/docs/source/auto_examples/plot_OTDA_classes.py deleted file mode 100644 index 089b45b..0000000 --- a/docs/source/auto_examples/plot_OTDA_classes.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -""" -======================== -OT for domain adaptation -======================== - -""" - -import matplotlib.pylab as pl -import ot - - - - -#%% parameters - -n=150 # nb samples in source and target datasets - -xs,ys=ot.datasets.get_data_classif('3gauss',n) -xt,yt=ot.datasets.get_data_classif('3gauss2',n) - - - - -#%% plot samples - -pl.figure(1) - -pl.subplot(2,2,1) -pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') -pl.legend(loc=0) -pl.title('Source distributions') - -pl.subplot(2,2,2) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') -pl.legend(loc=0) -pl.title('target distributions') - - -#%% OT estimation - -# LP problem -da_emd=ot.da.OTDA() # init class -da_emd.fit(xs,xt) # fit distributions -xst0=da_emd.interp() # interpolation of source samples - - -# sinkhorn regularization -lambd=1e-1 -da_entrop=ot.da.OTDA_sinkhorn() -da_entrop.fit(xs,xt,reg=lambd) -xsts=da_entrop.interp() - -# non-convex Group lasso regularization -reg=1e-1 -eta=1e0 -da_lpl1=ot.da.OTDA_lpl1() -da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) -xstg=da_lpl1.interp() - - -# True Group lasso regularization -reg=1e-1 -eta=2e0 -da_l1l2=ot.da.OTDA_l1l2() -da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) -xstgl=da_l1l2.interp() - - -#%% plot interpolated source samples -pl.figure(4,(15,8)) - -param_img={'interpolation':'nearest','cmap':'jet'} - -pl.subplot(2,4,1) -pl.imshow(da_emd.G,**param_img) -pl.title('OT matrix') - - -pl.subplot(2,4,2) -pl.imshow(da_entrop.G,**param_img) -pl.title('OT matrix sinkhorn') - -pl.subplot(2,4,3) -pl.imshow(da_lpl1.G,**param_img) -pl.title('OT matrix non-convex Group Lasso') - -pl.subplot(2,4,4) -pl.imshow(da_l1l2.G,**param_img) -pl.title('OT matrix Group Lasso') - - -pl.subplot(2,4,5) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples') -pl.legend(loc=0) - -pl.subplot(2,4,6) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples Sinkhorn') - -pl.subplot(2,4,7) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples non-convex Group Lasso') - -pl.subplot(2,4,8) -pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) -pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) -pl.title('Interp samples Group Lasso')
\ No newline at end of file |