summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_WDA.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-08-30 17:01:01 +0200
commitdc8737a30cb6d9f1305173eb8d16fe6716fd1231 (patch)
tree1f03384de2af88ed07a1e850e0871db826ed53e7 /docs/source/auto_examples/plot_WDA.py
parentc2a7a1f3ab4ba5c4f5adeca0fa22d8d6b4fc079d (diff)
working make!
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.py')
-rw-r--r--docs/source/auto_examples/plot_WDA.py95
1 files changed, 66 insertions, 29 deletions
diff --git a/docs/source/auto_examples/plot_WDA.py b/docs/source/auto_examples/plot_WDA.py
index bbe3888..42789f2 100644
--- a/docs/source/auto_examples/plot_WDA.py
+++ b/docs/source/auto_examples/plot_WDA.py
@@ -4,60 +4,97 @@
Wasserstein Discriminant Analysis
=================================
-@author: rflamary
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
-import ot
-from ot.datasets import get_1D_gauss as gauss
-from ot.dr import wda
+
+from ot.dr import wda, fda
#%% parameters
-n=1000 # nb samples in source and target datasets
-nz=0.2
-xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
-xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
+n = 1000 # nb samples in source and target datasets
+nz = 0.2
-nbnoise=8
+# generate circle dataset
+t = np.random.rand(n) * 2 * np.pi
+ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
-xs=np.hstack((xs,np.random.randn(n,nbnoise)))
-xt=np.hstack((xt,np.random.randn(n,nbnoise)))
+t = np.random.rand(n) * 2 * np.pi
+yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
-#%% plot samples
+nbnoise = 8
-pl.figure(1)
+xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+#%% plot samples
+pl.figure(1, figsize=(6.4, 3.5))
-pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+pl.subplot(1, 2, 1)
+pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')
+pl.subplot(1, 2, 2)
+pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+pl.legend(loc=0)
+pl.title('Other dimensions')
+pl.tight_layout()
+
+#%% Compute FDA
+p = 2
-#%% plot distributions and loss matrix
-p=2
-reg=1
-k=10
-maxiter=100
+Pfda, projfda = fda(xs, ys, p)
-P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+#%% Compute WDA
+p = 2
+reg = 1e0
+k = 10
+maxiter = 100
+
+Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
#%% plot samples
-xsp=proj(xs)
-xtp=proj(xt)
+xsp = projfda(xs)
+xtp = projfda(xt)
+
+xspw = projwda(xs)
+xtpw = projwda(xt)
+
+pl.figure(2)
-pl.figure(1,(10,5))
+pl.subplot(2, 2, 1)
+pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples FDA')
-pl.subplot(1,2,1)
-pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 2)
+pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
-pl.title('Projected training samples')
+pl.title('Projected test samples FDA')
+pl.subplot(2, 2, 3)
+pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples WDA')
-pl.subplot(1,2,2)
-pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 4)
+pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
-pl.title('Projected test samples')
+pl.title('Projected test samples WDA')
+pl.tight_layout()
+
+pl.show()