diff options
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.py')
-rw-r--r-- | docs/source/auto_examples/plot_WDA.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_WDA.py b/docs/source/auto_examples/plot_WDA.py new file mode 100644 index 0000000..94b7ef4 --- /dev/null +++ b/docs/source/auto_examples/plot_WDA.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +""" +================================= +WAsserstein Discriminant Analysis +================================= + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from ot.datasets import get_1D_gauss as gauss +from ot.dr import wda + + +#%% parameters + +n=1000 # nb samples in source and target datasets +nz=0.2 +xs,ys=ot.datasets.get_data_classif('3gauss',n,nz) +xt,yt=ot.datasets.get_data_classif('3gauss',n,nz) + +nbnoise=8 + +xs=np.hstack((xs,np.random.randn(n,nbnoise))) +xt=np.hstack((xt,np.random.randn(n,nbnoise))) + +#%% plot samples + +pl.figure(1) + + +pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') +pl.legend(loc=0) +pl.title('Discriminant dimensions') + + +#%% plot distributions and loss matrix +p=2 +reg=1 +k=10 +maxiter=100 + +P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter) + +#%% plot samples + +xsp=proj(xs) +xtp=proj(xt) + +pl.figure(1,(10,5)) + +pl.subplot(1,2,1) +pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') +pl.legend(loc=0) +pl.title('Projected training samples') + + +pl.subplot(1,2,2) +pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') +pl.legend(loc=0) +pl.title('Projected test samples') |