summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_WDA.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.rst')
-rw-r--r--docs/source/auto_examples/plot_WDA.rst244
1 files changed, 0 insertions, 244 deletions
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst
deleted file mode 100644
index 2d83123..0000000
--- a/docs/source/auto_examples/plot_WDA.rst
+++ /dev/null
@@ -1,244 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_WDA.py:
-
-
-=================================
-Wasserstein Discriminant Analysis
-=================================
-
-This example illustrate the use of WDA as proposed in [11].
-
-
-[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
-Wasserstein Discriminant Analysis.
-
-
-
-
-.. code-block:: python
-
-
- # Author: Remi Flamary <remi.flamary@unice.fr>
- #
- # License: MIT License
-
- import numpy as np
- import matplotlib.pylab as pl
-
- from ot.dr import wda, fda
-
-
-
-
-
-
-
-
-Generate data
--------------
-
-
-
-.. code-block:: python
-
-
- #%% parameters
-
- n = 1000 # nb samples in source and target datasets
- nz = 0.2
-
- # generate circle dataset
- t = np.random.rand(n) * 2 * np.pi
- ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
- xs = np.concatenate(
- (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
- xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
-
- t = np.random.rand(n) * 2 * np.pi
- yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
- xt = np.concatenate(
- (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
- xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
-
- nbnoise = 8
-
- xs = np.hstack((xs, np.random.randn(n, nbnoise)))
- xt = np.hstack((xt, np.random.randn(n, nbnoise)))
-
-
-
-
-
-
-
-Plot data
----------
-
-
-
-.. code-block:: python
-
-
- #%% plot samples
- pl.figure(1, figsize=(6.4, 3.5))
-
- pl.subplot(1, 2, 1)
- pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
- pl.legend(loc=0)
- pl.title('Discriminant dimensions')
-
- pl.subplot(1, 2, 2)
- pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
- pl.legend(loc=0)
- pl.title('Other dimensions')
- pl.tight_layout()
-
-
-
-
-.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
- :align: center
-
-
-
-
-Compute Fisher Discriminant Analysis
-------------------------------------
-
-
-
-.. code-block:: python
-
-
- #%% Compute FDA
- p = 2
-
- Pfda, projfda = fda(xs, ys, p)
-
-
-
-
-
-
-
-Compute Wasserstein Discriminant Analysis
------------------------------------------
-
-
-
-.. code-block:: python
-
-
- #%% Compute WDA
- p = 2
- reg = 1e0
- k = 10
- maxiter = 100
-
- Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
-
-
-
-
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out::
-
- Compiling cost function...
- Computing gradient of cost function...
- iter cost val grad. norm
- 1 +9.0167295050534191e-01 2.28422652e-01
- 2 +4.8324990550878105e-01 4.89362707e-01
- 3 +3.4613154515357075e-01 2.84117562e-01
- 4 +2.5277108387195002e-01 1.24888750e-01
- 5 +2.4113858393736629e-01 8.07491482e-02
- 6 +2.3642108593032782e-01 1.67612140e-02
- 7 +2.3625721372202199e-01 7.68640008e-03
- 8 +2.3625461994913738e-01 7.42200784e-03
- 9 +2.3624493441436939e-01 6.43534105e-03
- 10 +2.3621901383686217e-01 2.17960585e-03
- 11 +2.3621854258326572e-01 2.03306749e-03
- 12 +2.3621696458678049e-01 1.37118721e-03
- 13 +2.3621569489873540e-01 2.76368907e-04
- 14 +2.3621565599232983e-01 1.41898134e-04
- 15 +2.3621564465487518e-01 5.96602069e-05
- 16 +2.3621564232556647e-01 1.08709521e-05
- 17 +2.3621564230277003e-01 9.17855656e-06
- 18 +2.3621564224857586e-01 1.73728345e-06
- 19 +2.3621564224748123e-01 1.17770019e-06
- 20 +2.3621564224658587e-01 2.16179383e-07
- Terminated - min grad norm reached after 20 iterations, 9.20 seconds.
-
-
-Plot 2D projections
--------------------
-
-
-
-.. code-block:: python
-
-
- #%% plot samples
-
- xsp = projfda(xs)
- xtp = projfda(xt)
-
- xspw = projwda(xs)
- xtpw = projwda(xt)
-
- pl.figure(2)
-
- pl.subplot(2, 2, 1)
- pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
- pl.legend(loc=0)
- pl.title('Projected training samples FDA')
-
- pl.subplot(2, 2, 2)
- pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
- pl.legend(loc=0)
- pl.title('Projected test samples FDA')
-
- pl.subplot(2, 2, 3)
- pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
- pl.legend(loc=0)
- pl.title('Projected training samples WDA')
-
- pl.subplot(2, 2, 4)
- pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
- pl.legend(loc=0)
- pl.title('Projected test samples WDA')
- pl.tight_layout()
-
- pl.show()
-
-
-
-.. image:: /auto_examples/images/sphx_glr_plot_WDA_003.png
- :align: center
-
-
-
-
-**Total running time of the script:** ( 0 minutes 16.182 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_WDA.py <plot_WDA.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_WDA.ipynb <plot_WDA.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_