From 20547ee48dbc5dd0f3127983e22c48e18260e35f Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 6 Jul 2017 09:17:52 +0200 Subject: update wda example --- examples/plot_WDA.py | 54 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 13 deletions(-) (limited to 'examples/plot_WDA.py') diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py index d2eaf6d..d96b27f 100644 --- a/examples/plot_WDA.py +++ b/examples/plot_WDA.py @@ -18,8 +18,17 @@ from ot.dr import wda, fda n=1000 # nb samples in source and target datasets nz=0.2 -xs,ys=ot.datasets.get_data_classif('3gauss',n,nz) -xt,yt=ot.datasets.get_data_classif('3gauss',n,nz) + +# generate circle dataset +t=np.random.rand(n)*2*np.pi +ys=np.floor((np.arange(n)*1.0/n*3))+1 +xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) +xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2) + +t=np.random.rand(n)*2*np.pi +yt=np.floor((np.arange(n)*1.0/n*3))+1 +xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) +xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2) nbnoise=8 @@ -27,42 +36,61 @@ xs=np.hstack((xs,np.random.randn(n,nbnoise))) xt=np.hstack((xt,np.random.randn(n,nbnoise))) #%% plot samples +pl.figure(1,(10,5)) -pl.figure(1) - - +pl.subplot(1,2,1) pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') pl.legend(loc=0) pl.title('Discriminant dimensions') +pl.subplot(1,2,2) +pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Other dimensions') +pl.show() -#%% Comlpute FDA +#%% Compute FDA p=2 Pfda,projfda = fda(xs,ys,p) #%% Compute WDA p=2 -reg=1 +reg=1e-1 k=10 maxiter=100 -P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter) +Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter) #%% plot samples xsp=projfda(xs) xtp=projfda(xt) -pl.figure(1,(10,5)) +xspw=projwda(xs) +xtpw=projwda(xt) -pl.subplot(1,2,1) +pl.figure(1,(10,10)) + +pl.subplot(2,2,1) pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') pl.legend(loc=0) -pl.title('Projected training samples') +pl.title('Projected training samples FDA') -pl.subplot(1,2,2) +pl.subplot(2,2,2) pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') pl.legend(loc=0) -pl.title('Projected test samples') +pl.title('Projected test samples FDA') + + +pl.subplot(2,2,3) +pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples') +pl.legend(loc=0) +pl.title('Projected training samples WDA') + + +pl.subplot(2,2,4) +pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples') +pl.legend(loc=0) +pl.title('Projected test samples WDA') -- cgit v1.2.3