summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2016-11-08 23:15:49 +0100
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2016-11-08 23:15:49 +0100
commit22036bd9db2cd5cc8329b27ca740ff4d9c114fb7 (patch)
treeb964586e8069b7197652febcafb6ab2918c0d33a /examples
parent0d1f3eb3c41c0b06edc70647037b6cda581e8e2d (diff)
da with GL
Diffstat (limited to 'examples')
-rw-r--r--examples/demo_OTDA_classes.py37
-rw-r--r--examples/demo_optim_OTreg.py4
2 files changed, 30 insertions, 11 deletions
diff --git a/examples/demo_OTDA_classes.py b/examples/demo_OTDA_classes.py
index 8a97c80..43fd37e 100644
--- a/examples/demo_OTDA_classes.py
+++ b/examples/demo_OTDA_classes.py
@@ -48,43 +48,62 @@ da_entrop=ot.da.OTDA_sinkhorn()
da_entrop.fit(xs,xt,reg=lambd)
xsts=da_entrop.interp()
-# Group lasso regularization
+# non-convex Group lasso regularization
reg=1e-1
eta=1e0
da_lpl1=ot.da.OTDA_lpl1()
-da_lpl1.fit(xs,ys,xt,reg=lambd,eta=eta)
+da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
xstg=da_lpl1.interp()
+
+# True Group lasso regularization
+reg=1e-1
+eta=1e1
+da_l1l2=ot.da.OTDA_l1l2()
+da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
+xstgl=da_l1l2.interp()
+
+
#%% plot interpolated source samples
-pl.figure(4,(15,10))
+pl.figure(4,(15,8))
param_img={'interpolation':'nearest','cmap':'jet'}
-pl.subplot(2,3,1)
+pl.subplot(2,4,1)
pl.imshow(da_emd.G,**param_img)
pl.title('OT matrix')
-pl.subplot(2,3,2)
+pl.subplot(2,4,2)
pl.imshow(da_entrop.G,**param_img)
pl.title('OT matrix sinkhorn')
-pl.subplot(2,3,3)
+pl.subplot(2,4,3)
pl.imshow(da_lpl1.G,**param_img)
+pl.title('OT matrix non-convex Group Lasso')
+
+pl.subplot(2,4,4)
+pl.imshow(da_l1l2.G,**param_img)
pl.title('OT matrix Group Lasso')
-pl.subplot(2,3,4)
+
+pl.subplot(2,4,5)
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
pl.title('Interp samples')
pl.legend(loc=0)
-pl.subplot(2,3,5)
+pl.subplot(2,4,6)
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
pl.title('Interp samples Sinkhorn')
-pl.subplot(2,3,6)
+pl.subplot(2,4,7)
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
+pl.title('Interp samples non-convex Group Lasso')
+
+pl.subplot(2,4,8)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
pl.title('Interp samples Group Lasso') \ No newline at end of file
diff --git a/examples/demo_optim_OTreg.py b/examples/demo_optim_OTreg.py
index a49b00c..0de6b08 100644
--- a/examples/demo_optim_OTreg.py
+++ b/examples/demo_optim_OTreg.py
@@ -60,8 +60,8 @@ ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')
def f(G): return 0.5*np.sum(G**2)
def df(G): return G
-reg1=1e-3
-reg2=1e-3
+reg1=1e-1
+reg2=1e-1
Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)