From 6d7fd7e9faffa777cef222bdfc48c7ad732ab950 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 12 Jul 2017 23:16:00 +0200 Subject: more --- examples/plot_WDA.py | 86 +++++++++++++++++++++++++++------------------------- 1 file changed, 44 insertions(+), 42 deletions(-) (limited to 'examples') diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py index 5fa2ab1..8a44022 100644 --- a/examples/plot_WDA.py +++ b/examples/plot_WDA.py @@ -16,81 +16,83 @@ from ot.dr import wda, fda #%% parameters -n=1000 # nb samples in source and target datasets -nz=0.2 +n = 1000 # nb samples in source and target datasets +nz = 0.2 # generate circle dataset -t=np.random.rand(n)*2*np.pi -ys=np.floor((np.arange(n)*1.0/n*3))+1 -xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) -xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2) +t = np.random.rand(n) * 2 * np.pi +ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 +xs = np.concatenate( + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) +xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2) -t=np.random.rand(n)*2*np.pi -yt=np.floor((np.arange(n)*1.0/n*3))+1 -xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1) -xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2) +t = np.random.rand(n) * 2 * np.pi +yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 +xt = np.concatenate( + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) +xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2) -nbnoise=8 +nbnoise = 8 -xs=np.hstack((xs,np.random.randn(n,nbnoise))) -xt=np.hstack((xt,np.random.randn(n,nbnoise))) +xs = np.hstack((xs, np.random.randn(n, nbnoise))) +xt = np.hstack((xt, np.random.randn(n, nbnoise))) #%% plot samples -pl.figure(1,(10,5)) +pl.figure(1, figsize=(6.4, 3.5)) -pl.subplot(1,2,1) -pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') +pl.subplot(1, 2, 1) +pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples') pl.legend(loc=0) pl.title('Discriminant dimensions') -pl.subplot(1,2,2) -pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples') +pl.subplot(1, 2, 2) +pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples') pl.legend(loc=0) pl.title('Other dimensions') -pl.show() +pl.tight_layout() #%% Compute FDA -p=2 +p = 2 -Pfda,projfda = fda(xs,ys,p) +Pfda, projfda = fda(xs, ys, p) #%% Compute WDA -p=2 -reg=1e0 -k=10 -maxiter=100 +p = 2 +reg = 1e0 +k = 10 +maxiter = 100 -Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter) +Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) #%% plot samples -xsp=projfda(xs) -xtp=projfda(xt) +xsp = projfda(xs) +xtp = projfda(xt) -xspw=projwda(xs) -xtpw=projwda(xt) +xspw = projwda(xs) +xtpw = projwda(xt) -pl.figure(1,(10,10)) +pl.figure(2) -pl.subplot(2,2,1) -pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') +pl.subplot(2, 2, 1) +pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples') pl.legend(loc=0) pl.title('Projected training samples FDA') - -pl.subplot(2,2,2) -pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') +pl.subplot(2, 2, 2) +pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples') pl.legend(loc=0) pl.title('Projected test samples FDA') - -pl.subplot(2,2,3) -pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples') +pl.subplot(2, 2, 3) +pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples') pl.legend(loc=0) pl.title('Projected training samples WDA') - -pl.subplot(2,2,4) -pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples') +pl.subplot(2, 2, 4) +pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples') pl.legend(loc=0) pl.title('Projected test samples WDA') +pl.tight_layout() + +pl.show() -- cgit v1.2.3