summaryrefslogtreecommitdiff
path: root/examples/plot_WDA.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-12 23:16:00 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:05:12 +0200
commit6d7fd7e9faffa777cef222bdfc48c7ad732ab950 (patch)
tree89a6cbc4f52cac2e0ef1057943c1f8d0f2d52723 /examples/plot_WDA.py
parent7ad472500dcce284231fc5968effa755802ab4ea (diff)
more
Diffstat (limited to 'examples/plot_WDA.py')
-rw-r--r--examples/plot_WDA.py86
1 files changed, 44 insertions, 42 deletions
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
index 5fa2ab1..8a44022 100644
--- a/examples/plot_WDA.py
+++ b/examples/plot_WDA.py
@@ -16,81 +16,83 @@ from ot.dr import wda, fda
#%% parameters
-n=1000 # nb samples in source and target datasets
-nz=0.2
+n = 1000 # nb samples in source and target datasets
+nz = 0.2
# generate circle dataset
-t=np.random.rand(n)*2*np.pi
-ys=np.floor((np.arange(n)*1.0/n*3))+1
-xs=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1)
-xs=xs*ys.reshape(-1,1)+nz*np.random.randn(n,2)
+t = np.random.rand(n) * 2 * np.pi
+ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
-t=np.random.rand(n)*2*np.pi
-yt=np.floor((np.arange(n)*1.0/n*3))+1
-xt=np.concatenate((np.cos(t).reshape((-1,1)),np.sin(t).reshape((-1,1))),1)
-xt=xt*yt.reshape(-1,1)+nz*np.random.randn(n,2)
+t = np.random.rand(n) * 2 * np.pi
+yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
-nbnoise=8
+nbnoise = 8
-xs=np.hstack((xs,np.random.randn(n,nbnoise)))
-xt=np.hstack((xt,np.random.randn(n,nbnoise)))
+xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+xt = np.hstack((xt, np.random.randn(n, nbnoise)))
#%% plot samples
-pl.figure(1,(10,5))
+pl.figure(1, figsize=(6.4, 3.5))
-pl.subplot(1,2,1)
-pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+pl.subplot(1, 2, 1)
+pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')
-pl.subplot(1,2,2)
-pl.scatter(xt[:,2],xt[:,3],c=ys,marker='+',label='Source samples')
+pl.subplot(1, 2, 2)
+pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Other dimensions')
-pl.show()
+pl.tight_layout()
#%% Compute FDA
-p=2
+p = 2
-Pfda,projfda = fda(xs,ys,p)
+Pfda, projfda = fda(xs, ys, p)
#%% Compute WDA
-p=2
-reg=1e0
-k=10
-maxiter=100
+p = 2
+reg = 1e0
+k = 10
+maxiter = 100
-Pwda,projwda = wda(xs,ys,p,reg,k,maxiter=maxiter)
+Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
#%% plot samples
-xsp=projfda(xs)
-xtp=projfda(xt)
+xsp = projfda(xs)
+xtp = projfda(xt)
-xspw=projwda(xs)
-xtpw=projwda(xt)
+xspw = projwda(xs)
+xtpw = projwda(xt)
-pl.figure(1,(10,10))
+pl.figure(2)
-pl.subplot(2,2,1)
-pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 1)
+pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
pl.title('Projected training samples FDA')
-
-pl.subplot(2,2,2)
-pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 2)
+pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
pl.title('Projected test samples FDA')
-
-pl.subplot(2,2,3)
-pl.scatter(xspw[:,0],xspw[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 3)
+pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
pl.title('Projected training samples WDA')
-
-pl.subplot(2,2,4)
-pl.scatter(xtpw[:,0],xtpw[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 4)
+pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
pl.title('Projected test samples WDA')
+pl.tight_layout()
+
+pl.show()