diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-06-09 13:57:07 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-06-09 13:57:07 +0200 |
commit | 05da582675c89ab20998e1a9505bf3c220e296b8 (patch) | |
tree | de91f6ac74e8bab182e5f54a4b9d8b5b0a58991d /docs/source/auto_examples/plot_WDA.rst | |
parent | c7a5e3290527c372aa203c18df5f054409e8a60c (diff) |
update doc
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.rst')
-rw-r--r-- | docs/source/auto_examples/plot_WDA.rst | 127 |
1 file changed, 127 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst new file mode 100644 index 0000000..379a133 --- /dev/null +++ b/docs/source/auto_examples/plot_WDA.rst @@ -0,0 +1,127 @@ + + +.. _sphx_glr_auto_examples_plot_WDA.py: + + +================================= +Wasserstein Discriminant Analysis +================================= + +@author: rflamary + + + + +.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png + :align: center + + +.. rst-class:: sphx-glr-script-out + + Out:: + + Compiling cost function... + Computing gradient of cost function... + iter cost val grad. norm + 1 +7.5272200933021116e-01 8.85804426e-01 + 2 +2.5764223980223788e-01 3.04501586e-01 + 3 +1.6018169776696620e-01 1.78298483e-01 + 4 +1.4560944642106255e-01 1.42133298e-01 + 5 +1.0243843483991794e-01 1.23342675e-01 + 6 +7.8856617504010643e-02 1.05379766e-01 + 7 +7.7620851864404483e-02 1.04044062e-01 + 8 +7.3160520861018416e-02 8.33770034e-02 + 9 +6.6999294576662857e-02 2.87368977e-02 + 10 +6.6250206928793964e-02 1.72155066e-03 + 11 +6.6247631521353170e-02 2.43806911e-04 + 12 +6.6247596955965438e-02 1.40066459e-04 + 13 +6.6247580176638649e-02 4.77471577e-06 + 14 +6.6247580163923028e-02 3.00484279e-06 + 15 +6.6247580159235792e-02 1.91039983e-06 + 16 +6.6247580156889613e-02 9.56038747e-07 + Terminated - min grad norm reached after 16 iterations, 7.78 seconds. + + + + +| + + +.. 
code-block:: python + + + import numpy as np + import matplotlib.pylab as pl + import ot + from ot.datasets import get_1D_gauss as gauss + from ot.dr import wda + + + #%% parameters + + n=1000 # nb samples in source and target datasets + nz=0.2 + xs,ys=ot.datasets.get_data_classif('3gauss',n,nz) + xt,yt=ot.datasets.get_data_classif('3gauss',n,nz) + + nbnoise=8 + + xs=np.hstack((xs,np.random.randn(n,nbnoise))) + xt=np.hstack((xt,np.random.randn(n,nbnoise))) + + #%% plot samples + + pl.figure(1) + + + pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') + pl.legend(loc=0) + pl.title('Discriminant dimensions') + + + #%% plot distributions and loss matrix + p=2 + reg=1 + k=10 + maxiter=100 + + P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter) + + #%% plot samples + + xsp=proj(xs) + xtp=proj(xt) + + pl.figure(1,(10,5)) + + pl.subplot(1,2,1) + pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') + pl.legend(loc=0) + pl.title('Projected training samples') + + + pl.subplot(1,2,2) + pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') + pl.legend(loc=0) + pl.title('Projected test samples') + +**Total running time of the script:** ( 0 minutes 14.134 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_WDA.py <plot_WDA.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_WDA.ipynb <plot_WDA.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ |