diff options
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.ipynb')
-rw-r--r-- | docs/source/auto_examples/plot_WDA.ipynb | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/docs/source/auto_examples/plot_WDA.ipynb b/docs/source/auto_examples/plot_WDA.ipynb index 408a605..5568128 100644 --- a/docs/source/auto_examples/plot_WDA.ipynb +++ b/docs/source/auto_examples/plot_WDA.ipynb @@ -15,7 +15,7 @@ }, { "source": [ - "\n# Wasserstein Discriminant Analysis\n\n\n@author: rflamary\n\n" + "\n# Wasserstein Discriminant Analysis\n\n\n\n" ], "cell_type": "markdown", "metadata": {} @@ -24,7 +24,7 @@ "execution_count": null, "cell_type": "code", "source": [ - "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\nfrom ot.dr import wda\n\n\n#%% parameters\n\nn=1000 # nb samples in source and target datasets\nnz=0.2\nxs,ys=ot.datasets.get_data_classif('3gauss',n,nz)\nxt,yt=ot.datasets.get_data_classif('3gauss',n,nz)\n\nnbnoise=8\n\nxs=np.hstack((xs,np.random.randn(n,nbnoise)))\nxt=np.hstack((xt,np.random.randn(n,nbnoise)))\n\n#%% plot samples\n\npl.figure(1)\n\n\npl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Discriminant dimensions')\n\n\n#%% plot distributions and loss matrix\np=2\nreg=1\nk=10\nmaxiter=100\n\nP,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)\n\n#%% plot samples\n\nxsp=proj(xs)\nxtp=proj(xt)\n\npl.figure(1,(10,5))\n\npl.subplot(1,2,1)\npl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples')\n\n\npl.subplot(1,2,2)\npl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples')" + "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\n\nfrom ot.dr import wda, fda\n\n\n#%% parameters\n\nn = 1000 # nb samples in source and target datasets\nnz = 0.2\n\n# generate circle dataset\nt = np.random.rand(n) * 2 * np.pi\nys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxs = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nt = np.random.rand(n) * 2 * np.pi\nyt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxt = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nnbnoise = 8\n\nxs = np.hstack((xs, np.random.randn(n, nbnoise)))\nxt = np.hstack((xt, np.random.randn(n, nbnoise)))\n\n#%% plot samples\npl.figure(1, figsize=(6.4, 3.5))\n\npl.subplot(1, 2, 1)\npl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Discriminant dimensions')\n\npl.subplot(1, 2, 2)\npl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Other dimensions')\npl.tight_layout()\n\n#%% Compute FDA\np = 2\n\nPfda, projfda = fda(xs, ys, p)\n\n#%% Compute WDA\np = 2\nreg = 1e0\nk = 10\nmaxiter = 100\n\nPwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)\n\n#%% plot samples\n\nxsp = projfda(xs)\nxtp = projfda(xt)\n\nxspw = projwda(xs)\nxtpw = projwda(xt)\n\npl.figure(2)\n\npl.subplot(2, 2, 1)\npl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples FDA')\n\npl.subplot(2, 2, 2)\npl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples FDA')\n\npl.subplot(2, 2, 3)\npl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples WDA')\n\npl.subplot(2, 2, 4)\npl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples WDA')\npl.tight_layout()\n\npl.show()" ], "outputs": [], "metadata": { |