diff options
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.rst')
-rw-r--r-- | docs/source/auto_examples/plot_WDA.rst | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst new file mode 100644 index 0000000..2d83123 --- /dev/null +++ b/docs/source/auto_examples/plot_WDA.rst @@ -0,0 +1,244 @@ + + +.. _sphx_glr_auto_examples_plot_WDA.py: + + +================================= +Wasserstein Discriminant Analysis +================================= + +This example illustrate the use of WDA as proposed in [11]. + + +[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). +Wasserstein Discriminant Analysis. + + + + +.. code-block:: python + + + # Author: Remi Flamary <remi.flamary@unice.fr> + # + # License: MIT License + + import numpy as np + import matplotlib.pylab as pl + + from ot.dr import wda, fda + + + + + + + + +Generate data +------------- + + + +.. code-block:: python + + + #%% parameters + + n = 1000 # nb samples in source and target datasets + nz = 0.2 + + # generate circle dataset + t = np.random.rand(n) * 2 * np.pi + ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 + xs = np.concatenate( + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) + xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2) + + t = np.random.rand(n) * 2 * np.pi + yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 + xt = np.concatenate( + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) + xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2) + + nbnoise = 8 + + xs = np.hstack((xs, np.random.randn(n, nbnoise))) + xt = np.hstack((xt, np.random.randn(n, nbnoise))) + + + + + + + +Plot data +--------- + + + +.. code-block:: python + + + #%% plot samples + pl.figure(1, figsize=(6.4, 3.5)) + + pl.subplot(1, 2, 1) + pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples') + pl.legend(loc=0) + pl.title('Discriminant dimensions') + + pl.subplot(1, 2, 2) + pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples') + pl.legend(loc=0) + pl.title('Other dimensions') + pl.tight_layout() + + + + +.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png + :align: center + + + + +Compute Fisher Discriminant Analysis +------------------------------------ + + + +.. code-block:: python + + + #%% Compute FDA + p = 2 + + Pfda, projfda = fda(xs, ys, p) + + + + + + + +Compute Wasserstein Discriminant Analysis +----------------------------------------- + + + +.. code-block:: python + + + #%% Compute WDA + p = 2 + reg = 1e0 + k = 10 + maxiter = 100 + + Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) + + + + + + +.. rst-class:: sphx-glr-script-out + + Out:: + + Compiling cost function... + Computing gradient of cost function... + iter cost val grad. norm + 1 +9.0167295050534191e-01 2.28422652e-01 + 2 +4.8324990550878105e-01 4.89362707e-01 + 3 +3.4613154515357075e-01 2.84117562e-01 + 4 +2.5277108387195002e-01 1.24888750e-01 + 5 +2.4113858393736629e-01 8.07491482e-02 + 6 +2.3642108593032782e-01 1.67612140e-02 + 7 +2.3625721372202199e-01 7.68640008e-03 + 8 +2.3625461994913738e-01 7.42200784e-03 + 9 +2.3624493441436939e-01 6.43534105e-03 + 10 +2.3621901383686217e-01 2.17960585e-03 + 11 +2.3621854258326572e-01 2.03306749e-03 + 12 +2.3621696458678049e-01 1.37118721e-03 + 13 +2.3621569489873540e-01 2.76368907e-04 + 14 +2.3621565599232983e-01 1.41898134e-04 + 15 +2.3621564465487518e-01 5.96602069e-05 + 16 +2.3621564232556647e-01 1.08709521e-05 + 17 +2.3621564230277003e-01 9.17855656e-06 + 18 +2.3621564224857586e-01 1.73728345e-06 + 19 +2.3621564224748123e-01 1.17770019e-06 + 20 +2.3621564224658587e-01 2.16179383e-07 + Terminated - min grad norm reached after 20 iterations, 9.20 seconds. + + +Plot 2D projections +------------------- + + + +.. code-block:: python + + + #%% plot samples + + xsp = projfda(xs) + xtp = projfda(xt) + + xspw = projwda(xs) + xtpw = projwda(xt) + + pl.figure(2) + + pl.subplot(2, 2, 1) + pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples') + pl.legend(loc=0) + pl.title('Projected training samples FDA') + + pl.subplot(2, 2, 2) + pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples') + pl.legend(loc=0) + pl.title('Projected test samples FDA') + + pl.subplot(2, 2, 3) + pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples') + pl.legend(loc=0) + pl.title('Projected training samples WDA') + + pl.subplot(2, 2, 4) + pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples') + pl.legend(loc=0) + pl.title('Projected test samples WDA') + pl.tight_layout() + + pl.show() + + + +.. image:: /auto_examples/images/sphx_glr_plot_WDA_003.png + :align: center + + + + +**Total running time of the script:** ( 0 minutes 16.182 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_WDA.py <plot_WDA.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_WDA.ipynb <plot_WDA.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ |