.. _sphx_glr_auto_examples_plot_WDA.py:


=================================
Wasserstein Discriminant Analysis
=================================

This example illustrates the use of Wasserstein Discriminant Analysis (WDA)
as proposed in [11] and compares it with Fisher Discriminant Analysis (FDA)
on a toy dataset in which only two of the ten dimensions carry class
information.

[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
Wasserstein Discriminant Analysis.


.. code-block:: python

    # Author: Remi Flamary
    #
    # License: MIT License

    import numpy as np
    import matplotlib.pyplot as pl

    from ot.dr import wda, fda


Generate data
-------------

Each dataset contains three classes lying on noisy concentric circles of
radius 1, 2 and 3 in the first two dimensions. Eight dimensions of pure
Gaussian noise are then appended, so the class information is confined to
the first two dimensions.

.. code-block:: python

    #%% parameters

    n = 1000  # number of samples in each of the source and target datasets
    nz = 0.2  # standard deviation of the noise added to the circles

    # source (training) dataset: the label (1, 2 or 3) sets the circle radius
    t = np.random.rand(n) * 2 * np.pi
    ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
    xs = np.concatenate(
        (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
    xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)

    # target (test) dataset, generated in the same way
    t = np.random.rand(n) * 2 * np.pi
    yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
    xt = np.concatenate(
        (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
    xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)

    # append non-discriminant noise dimensions
    nbnoise = 8

    xs = np.hstack((xs, np.random.randn(n, nbnoise)))
    xt = np.hstack((xt, np.random.randn(n, nbnoise)))


Plot data
---------

.. code-block:: python

    #%% plot source samples
    pl.figure(1, figsize=(6.4, 3.5))

    pl.subplot(1, 2, 1)
    pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.legend(loc=0)
    pl.title('Discriminant dimensions')

    pl.subplot(1, 2, 2)
    pl.scatter(xs[:, 2], xs[:, 3], c=ys, marker='+', label='Source samples')
    pl.legend(loc=0)
    pl.title('Other dimensions')
    pl.tight_layout()


.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
    :align: center


Compute Fisher Discriminant Analysis
------------------------------------

.. code-block:: python

    #%% Compute FDA
    p = 2  # dimension of the projected space

    Pfda, projfda = fda(xs, ys, p)


Compute Wasserstein Discriminant Analysis
-----------------------------------------

.. code-block:: python

    #%% Compute WDA
    p = 2  # dimension of the projected space
    reg = 1e0  # entropic regularization
    k = 10  # number of Sinkhorn iterations
    maxiter = 100  # maximum number of optimizer iterations

    Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)

.. rst-class:: sphx-glr-script-out

 Out::

    Compiling cost function...
    Computing gradient of cost function...
     iter       cost val                  grad. norm
        1   +9.0167295050534191e-01   2.28422652e-01
        2   +4.8324990550878105e-01   4.89362707e-01
        3   +3.4613154515357075e-01   2.84117562e-01
        4   +2.5277108387195002e-01   1.24888750e-01
        5   +2.4113858393736629e-01   8.07491482e-02
        6   +2.3642108593032782e-01   1.67612140e-02
        7   +2.3625721372202199e-01   7.68640008e-03
        8   +2.3625461994913738e-01   7.42200784e-03
        9   +2.3624493441436939e-01   6.43534105e-03
       10   +2.3621901383686217e-01   2.17960585e-03
       11   +2.3621854258326572e-01   2.03306749e-03
       12   +2.3621696458678049e-01   1.37118721e-03
       13   +2.3621569489873540e-01   2.76368907e-04
       14   +2.3621565599232983e-01   1.41898134e-04
       15   +2.3621564465487518e-01   5.96602069e-05
       16   +2.3621564232556647e-01   1.08709521e-05
       17   +2.3621564230277003e-01   9.17855656e-06
       18   +2.3621564224857586e-01   1.73728345e-06
       19   +2.3621564224748123e-01   1.17770019e-06
       20   +2.3621564224658587e-01   2.16179383e-07
    Terminated - min grad norm reached after 20 iterations, 9.20 seconds.


Plot 2D projections
-------------------

Both projections are learned on the source samples only and then applied to
the target samples, which serve as a test set.

.. code-block:: python

    #%% plot projected samples
    xsp = projfda(xs)
    xtp = projfda(xt)

    xspw = projwda(xs)
    xtpw = projwda(xt)

    pl.figure(2)

    pl.subplot(2, 2, 1)
    pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected training samples FDA')

    pl.subplot(2, 2, 2)
    pl.scatter(xtp[:, 0], xtp[:, 1], c=yt, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected test samples FDA')

    pl.subplot(2, 2, 3)
    pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected training samples WDA')

    pl.subplot(2, 2, 4)
    pl.scatter(xtpw[:, 0], xtpw[:, 1], c=yt, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected test samples WDA')
    pl.tight_layout()

    pl.show()


.. image:: /auto_examples/images/sphx_glr_plot_WDA_003.png
    :align: center


**Total running time of the script:** ( 0 minutes 16.182 seconds)

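
Sketch of the WDA objective
---------------------------

As described in [11], WDA looks for a linear projection ``P`` under which
the entropy-regularized Wasserstein distances between different classes are
large compared with the distances within each class. In the call above,
``reg`` sets the entropic regularization and ``k`` the number of Sinkhorn
iterations. The block below is a minimal NumPy sketch of such a
within/between ratio for a *fixed* ``P``, written for illustration only: it
is not the ``ot.dr.wda`` implementation and does not optimize ``P``. It
assumes uniform sample weights, a squared Euclidean ground cost, a
regularization scaled by the largest cost entry of each class pair, and a
within-class term obtained by comparing each class with itself. The helper
names ``sinkhorn_plan`` and ``wda_objective`` are hypothetical.

.. code-block:: python

    import numpy as np


    def sinkhorn_plan(a, b, M, reg, n_iter):
        """Entropy-regularized OT plan after n_iter Sinkhorn scaling steps."""
        K = np.exp(-M / reg)  # Gibbs kernel
        u = np.ones_like(a)
        v = np.ones_like(b)
        for _ in range(n_iter):
            u = a / (K @ v)    # rescale rows toward marginal a
            v = b / (K.T @ u)  # rescale columns toward marginal b
        return u[:, None] * K * v[None, :]


    def wda_objective(X, y, P, reg=1.0, k=10):
        """Within-class / between-class regularized OT cost after projecting by P.

        Hypothetical helper: smaller values mean the classes are more spread
        apart relative to their own extent in the projected space.
        """
        Z = X @ P  # (n, p) projected samples
        classes = np.unique(y)
        within, between = 0.0, 0.0
        for i, ci in enumerate(classes):
            zi = Z[y == ci]
            for cj in classes[i:]:  # class pairs with cj >= ci, including ci == cj
                zj = Z[y == cj]
                # squared Euclidean ground cost between projected samples
                M = ((zi[:, None, :] - zj[None, :, :]) ** 2).sum(axis=-1)
                a = np.full(len(zi), 1.0 / len(zi))  # uniform weights (assumption)
                b = np.full(len(zj), 1.0 / len(zj))
                # scaling reg by max(M) keeps exp(-M / reg) away from underflow
                G = sinkhorn_plan(a, b, M, reg * M.max(), k)
                cost = np.sum(G * M)
                if ci == cj:
                    within += cost
                else:
                    between += cost
        return within / between


    # usage on a smaller dataset built like the one above (hypothetical setup)
    rng = np.random.default_rng(0)
    n = 150
    t = rng.random(n) * 2 * np.pi
    y = np.floor(np.arange(n) * 3.0 / n) + 1
    X = np.column_stack((np.cos(t), np.sin(t))) * y[:, None]
    X = X + 0.2 * rng.standard_normal((n, 2))
    X = np.hstack((X, rng.standard_normal((n, 8))))

    P_disc = np.eye(10)[:, :2]    # keep the two circle dimensions
    P_noise = np.eye(10)[:, 2:4]  # keep two pure-noise dimensions
    print(wda_objective(X, y, P_disc), wda_objective(X, y, P_noise))

Because the regularization is scaled by the largest cost of each class
pair, multiplying ``P`` by a positive constant leaves the transport plans,
and therefore the ratio, unchanged; the ratio only reflects how the
projected classes are arranged relative to one another.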
.. container:: sphx-glr-footer


  .. container:: sphx-glr-download

     :download:`Download Python source code: plot_WDA.py `



  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: plot_WDA.ipynb `

.. rst-class:: sphx-glr-signature

    `Generated by Sphinx-Gallery `_