diff options
author | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2016-11-08 23:15:49 +0100 |
---|---|---|
committer | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2016-11-08 23:15:49 +0100 |
commit | 22036bd9db2cd5cc8329b27ca740ff4d9c114fb7 (patch) | |
tree | b964586e8069b7197652febcafb6ab2918c0d33a /examples | |
parent | 0d1f3eb3c41c0b06edc70647037b6cda581e8e2d (diff) |
da with GL
Diffstat (limited to 'examples')
-rw-r--r-- | examples/demo_OTDA_classes.py | 37 | ||||
-rw-r--r-- | examples/demo_optim_OTreg.py | 4 |
2 files changed, 30 insertions, 11 deletions
diff --git a/examples/demo_OTDA_classes.py b/examples/demo_OTDA_classes.py index 8a97c80..43fd37e 100644 --- a/examples/demo_OTDA_classes.py +++ b/examples/demo_OTDA_classes.py @@ -48,43 +48,62 @@ da_entrop=ot.da.OTDA_sinkhorn() da_entrop.fit(xs,xt,reg=lambd) xsts=da_entrop.interp() -# Group lasso regularization +# non-convex Group lasso regularization reg=1e-1 eta=1e0 da_lpl1=ot.da.OTDA_lpl1() -da_lpl1.fit(xs,ys,xt,reg=lambd,eta=eta) +da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) xstg=da_lpl1.interp() + +# True Group lasso regularization +reg=1e-1 +eta=1e1 +da_l1l2=ot.da.OTDA_l1l2() +da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) +xstgl=da_l1l2.interp() + + #%% plot interpolated source samples -pl.figure(4,(15,10)) +pl.figure(4,(15,8)) param_img={'interpolation':'nearest','cmap':'jet'} -pl.subplot(2,3,1) +pl.subplot(2,4,1) pl.imshow(da_emd.G,**param_img) pl.title('OT matrix') -pl.subplot(2,3,2) +pl.subplot(2,4,2) pl.imshow(da_entrop.G,**param_img) pl.title('OT matrix sinkhorn') -pl.subplot(2,3,3) +pl.subplot(2,4,3) pl.imshow(da_lpl1.G,**param_img) +pl.title('OT matrix non-convex Group Lasso') + +pl.subplot(2,4,4) +pl.imshow(da_l1l2.G,**param_img) pl.title('OT matrix Group Lasso') -pl.subplot(2,3,4) + +pl.subplot(2,4,5) pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) pl.title('Interp samples') pl.legend(loc=0) -pl.subplot(2,3,5) +pl.subplot(2,4,6) pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) pl.title('Interp samples Sinkhorn') -pl.subplot(2,3,6) +pl.subplot(2,4,7) pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) +pl.title('Interp samples non-convex Group Lasso') + +pl.subplot(2,4,8) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) +pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) pl.title('Interp samples Group Lasso')
\ No newline at end of file diff --git a/examples/demo_optim_OTreg.py b/examples/demo_optim_OTreg.py index a49b00c..0de6b08 100644 --- a/examples/demo_optim_OTreg.py +++ b/examples/demo_optim_OTreg.py @@ -60,8 +60,8 @@ ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg') def f(G): return 0.5*np.sum(G**2) def df(G): return G -reg1=1e-3 -reg2=1e-3 +reg1=1e-1 +reg2=1e-1 Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True) |