{ "nbformat_minor": 0, "nbformat": 4, "cells": [ { "execution_count": null, "cell_type": "code", "source": [ "%matplotlib inline" ], "outputs": [], "metadata": { "collapsed": false } }, { "source": [ "\n# Wasserstein Discriminant Analysis\n\n\nThis example illustrate the use of WDA as proposed in [11].\n\n\n[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).\nWasserstein Discriminant Analysis.\n\n\n" ], "cell_type": "markdown", "metadata": {} }, { "execution_count": null, "cell_type": "code", "source": [ "# Author: Remi Flamary \n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\n\nfrom ot.dr import wda, fda" ], "outputs": [], "metadata": { "collapsed": false } }, { "source": [ "Generate data\n-------------\n\n" ], "cell_type": "markdown", "metadata": {} }, { "execution_count": null, "cell_type": "code", "source": [ "#%% parameters\n\nn = 1000 # nb samples in source and target datasets\nnz = 0.2\n\n# generate circle dataset\nt = np.random.rand(n) * 2 * np.pi\nys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxs = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nt = np.random.rand(n) * 2 * np.pi\nyt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxt = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nnbnoise = 8\n\nxs = np.hstack((xs, np.random.randn(n, nbnoise)))\nxt = np.hstack((xt, np.random.randn(n, nbnoise)))" ], "outputs": [], "metadata": { "collapsed": false } }, { "source": [ "Plot data\n---------\n\n" ], "cell_type": "markdown", "metadata": {} }, { "execution_count": null, "cell_type": "code", "source": [ "#%% plot samples\npl.figure(1, figsize=(6.4, 3.5))\n\npl.subplot(1, 2, 1)\npl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Discriminant dimensions')\n\npl.subplot(1, 2, 2)\npl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Other dimensions')\npl.tight_layout()" ], "outputs": [], "metadata": { "collapsed": false } }, { "source": [ "Compute Fisher Discriminant Analysis\n------------------------------------\n\n" ], "cell_type": "markdown", "metadata": {} }, { "execution_count": null, "cell_type": "code", "source": [ "#%% Compute FDA\np = 2\n\nPfda, projfda = fda(xs, ys, p)" ], "outputs": [], "metadata": { "collapsed": false } }, { "source": [ "Compute Wasserstein Discriminant Analysis\n-----------------------------------------\n\n" ], "cell_type": "markdown", "metadata": {} }, { "execution_count": null, "cell_type": "code", "source": [ "#%% Compute WDA\np = 2\nreg = 1e0\nk = 10\nmaxiter = 100\n\nPwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)" ], "outputs": [], "metadata": { "collapsed": false } }, { "source": [ "Plot 2D projections\n-------------------\n\n" ], "cell_type": "markdown", "metadata": {} }, { "execution_count": null, "cell_type": "code", "source": [ "#%% plot samples\n\nxsp = projfda(xs)\nxtp = projfda(xt)\n\nxspw = projwda(xs)\nxtpw = projwda(xt)\n\npl.figure(2)\n\npl.subplot(2, 2, 1)\npl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples FDA')\n\npl.subplot(2, 2, 2)\npl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples FDA')\n\npl.subplot(2, 2, 3)\npl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples WDA')\n\npl.subplot(2, 2, 4)\npl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples WDA')\npl.tight_layout()\n\npl.show()" ], "outputs": [], "metadata": { "collapsed": false } } ], "metadata": { "kernelspec": { "display_name": "Python 2", "name": "python2", "language": "python" }, "language_info": { "mimetype": "text/x-python", "nbconvert_exporter": "python", "name": "python", "file_extension": ".py", "version": "2.7.12", "pygments_lexer": "ipython2", "codemirror_mode": { "version": 2, "name": "ipython" } } } }