summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_WDA.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.rst')
-rw-r--r--docs/source/auto_examples/plot_WDA.rst127
1 files changed, 127 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst
new file mode 100644
index 0000000..379a133
--- /dev/null
+++ b/docs/source/auto_examples/plot_WDA.rst
@@ -0,0 +1,127 @@
+
+
+.. _sphx_glr_auto_examples_plot_WDA.py:
+
+
+=================================
+WAsserstein Discriminant Analysis
+=================================
+
+@author: rflamary
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Compiling cost function...
+ Computing gradient of cost function...
+ iter cost val grad. norm
+ 1 +7.5272200933021116e-01 8.85804426e-01
+ 2 +2.5764223980223788e-01 3.04501586e-01
+ 3 +1.6018169776696620e-01 1.78298483e-01
+ 4 +1.4560944642106255e-01 1.42133298e-01
+ 5 +1.0243843483991794e-01 1.23342675e-01
+ 6 +7.8856617504010643e-02 1.05379766e-01
+ 7 +7.7620851864404483e-02 1.04044062e-01
+ 8 +7.3160520861018416e-02 8.33770034e-02
+ 9 +6.6999294576662857e-02 2.87368977e-02
+ 10 +6.6250206928793964e-02 1.72155066e-03
+ 11 +6.6247631521353170e-02 2.43806911e-04
+ 12 +6.6247596955965438e-02 1.40066459e-04
+ 13 +6.6247580176638649e-02 4.77471577e-06
+ 14 +6.6247580163923028e-02 3.00484279e-06
+ 15 +6.6247580159235792e-02 1.91039983e-06
+ 16 +6.6247580156889613e-02 9.56038747e-07
+ Terminated - min grad norm reached after 16 iterations, 7.78 seconds.
+
+
+
+
+|
+
+
+.. code-block:: python
+
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ from ot.datasets import get_1D_gauss as gauss
+ from ot.dr import wda
+
+
+ #%% parameters
+
+ n=1000 # nb samples in source and target datasets
+ nz=0.2
+ xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
+ xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
+
+ nbnoise=8
+
+ xs=np.hstack((xs,np.random.randn(n,nbnoise)))
+ xt=np.hstack((xt,np.random.randn(n,nbnoise)))
+
+ #%% plot samples
+
+ pl.figure(1)
+
+
+ pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+ pl.legend(loc=0)
+ pl.title('Discriminant dimensions')
+
+
+ #%% plot distributions and loss matrix
+ p=2
+ reg=1
+ k=10
+ maxiter=100
+
+ P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+
+ #%% plot samples
+
+ xsp=proj(xs)
+ xtp=proj(xt)
+
+ pl.figure(1,(10,5))
+
+ pl.subplot(1,2,1)
+ pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected training samples')
+
+
+ pl.subplot(1,2,2)
+ pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected test samples')
+
+**Total running time of the script:** ( 0 minutes 14.134 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_WDA.py <plot_WDA.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_WDA.ipynb <plot_WDA.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_