summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OTDA_2D.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_OTDA_2D.ipynb')
-rw-r--r--docs/source/auto_examples/plot_OTDA_2D.ipynb54
1 files changed, 0 insertions, 54 deletions
diff --git a/docs/source/auto_examples/plot_OTDA_2D.ipynb b/docs/source/auto_examples/plot_OTDA_2D.ipynb
deleted file mode 100644
index 2ffb256..0000000
--- a/docs/source/auto_examples/plot_OTDA_2D.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n# OT for empirical distributions\n\n\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=150 # nb bins\n\nxs,ys=ot.datasets.get_data_classif('3gauss',n)\nxt,yt=ot.datasets.get_data_classif('3gauss2',n)\n\na,b = ot.unif(n),ot.unif(n)\n# loss matrix\nM=ot.dist(xs,xt)\n#M/=M.max()\n\n#%% plot samples\n\npl.figure(1)\n\npl.subplot(2,2,1)\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Source distributions')\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\npl.legend(loc=0)\npl.title('target distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% OT estimation\n\n# EMD\nG0=ot.emd(a,b,M)\n\n# sinkhorn\nlambd=1e-1\nGs=ot.sinkhorn(a,b,M,lambd)\n\n\n# Group lasso regularization\nreg=1e-1\neta=1e0\nGg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)\n\n\n#%% visu matrices\n\npl.figure(3)\n\npl.subplot(2,3,1)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix ')\n\npl.subplot(2,3,2)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix Sinkhorn')\n\npl.subplot(2,3,3)\npl.imshow(Gg,interpolation='nearest')\npl.title('OT matrix Group lasso')\n\npl.subplot(2,3,4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\n\npl.subplot(2,3,5)\not.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\npl.subplot(2,3,6)\not.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\n#%% sample interpolation\n\nxst0=n*G0.dot(xt)\nxsts=n*Gs.dot(xt)\nxstg=n*Gg.dot(xt)\n\npl.figure(4)\npl.subplot(2,3,1)\n\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples')\npl.legend(loc=0)\n\npl.subplot(2,3,2)\n\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Sinkhorn')\n\npl.subplot(2,3,3)\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Grouplasso')"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file