summaryrefslogtreecommitdiff
path: root/examples/plot_WDA.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-04-07 14:00:10 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-04-07 14:00:10 +0200
commit140baad30ab822deeccd6f1cb4fedc3136370ab4 (patch)
treed93212407588cdfa6b6807a82cc7688636b7d10c /examples/plot_WDA.py
parentb30a380fe5a5115cd0a5596f08903259c077f12c (diff)
add WDA
Diffstat (limited to 'examples/plot_WDA.py')
-rw-r--r--examples/plot_WDA.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
new file mode 100644
index 0000000..8d24bdc
--- /dev/null
+++ b/examples/plot_WDA.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+"""
+====================
+1D optimal transport
+====================
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from ot.datasets import get_1D_gauss as gauss
+from ot.dr import wda
+
+
+#%% parameters
+
+n=1000 # nb samples in source and target datasets
+nz=0.2
+xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
+xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
+
+nbnoise=8
+
+xs=np.hstack((xs,np.random.randn(n,nbnoise)))
+xt=np.hstack((xt,np.random.randn(n,nbnoise)))
+
+#%% plot samples
+
+pl.figure(1)
+
+
+pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+pl.legend(loc=0)
+pl.title('Discriminant dimensions')
+
+
+#%% plot distributions and loss matrix
+p=2
+reg=1
+k=10
+maxiter=100
+
+P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+
+#%% plot samples
+
+xsp=proj(xs)
+xtp=proj(xt)
+
+pl.figure(1,(10,5))
+
+pl.subplot(1,2,1)
+pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples')
+
+
+pl.subplot(1,2,2)
+pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples')