{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# OT for empirical distributions\n",
    "\n",
    "This example compares three optimal transport (OT) plans between two labeled 2D point clouds (`n` samples each, uniform weights):\n",
    "\n",
    "- exact OT (EMD) with `ot.emd`,\n",
    "- entropic regularized OT with `ot.sinkhorn`,\n",
    "- entropic regularized OT with an additional group-lasso penalty built from the source labels, with `ot.da.sinkhorn_lpl1_mm`.\n",
    "\n",
    "Each plan is shown both as a matrix and as links between source and target samples. Finally, every source sample is mapped to a weighted average of the target samples, `n * G.dot(xt)`. Because the rows of a plan `G` sum to the uniform source weight `1/n` (exactly for EMD, approximately for the regularized solvers), this is the barycentric mapping induced by `G`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as pl\n",
    "import ot\n",
    "import ot.plot  # recent POT versions do not load the plotting helpers with `import ot`\n",
    "\n",
    "\n",
    "\n",
    "#%% parameters\n",
    "\n",
    "n=150 # nb samples in each distribution\n",
    "\n",
    "xs,ys=ot.datasets.get_data_classif('3gauss',n)\n",
    "xt,yt=ot.datasets.get_data_classif('3gauss2',n)\n",
    "\n",
    "# uniform weights on the source and target samples\n",
    "a,b = ot.unif(n),ot.unif(n)\n",
    "# loss matrix: pairwise costs between source and target samples\n",
    "M=ot.dist(xs,xt)\n",
    "# optional: normalize the costs to [0, 1]; this changes the effective\n",
    "# strength of the entropic regularization used by the Sinkhorn solvers\n",
    "#M/=M.max()\n",
    "\n",
    "#%% plot samples\n",
    "\n",
    "pl.figure(1)\n",
    "\n",
    "pl.subplot(2,2,1)\n",
    "pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\n",
    "pl.legend(loc=0)\n",
    "pl.title('Source distributions')\n",
    "\n",
    "pl.subplot(2,2,2)\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n",
    "pl.legend(loc=0)\n",
    "pl.title('target distributions')\n",
    "\n",
    "pl.figure(2)\n",
    "pl.imshow(M,interpolation='nearest')\n",
    "pl.title('Cost matrix M')\n",
    "\n",
    "\n",
    "#%% OT estimation\n",
    "\n",
    "# EMD (exact, unregularized OT)\n",
    "G0=ot.emd(a,b,M)\n",
    "\n",
    "# sinkhorn: entropic regularization of strength lambd\n",
    "lambd=1e-1\n",
    "Gs=ot.sinkhorn(a,b,M,lambd)\n",
    "\n",
    "\n",
    "# Group lasso regularization: reg weights the entropic term and eta the\n",
    "# group-lasso term built from the (integer) source labels\n",
    "reg=1e-1\n",
    "eta=1e0\n",
    "# builtin int instead of np.int, which was removed in NumPy 1.24\n",
    "Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(int),b,M,reg,eta)\n",
    "\n",
    "\n",
    "#%% visu matrices\n",
    "\n",
    "pl.figure(3)\n",
    "\n",
    "pl.subplot(2,3,1)\n",
    "pl.imshow(G0,interpolation='nearest')\n",
    "pl.title('OT matrix ')\n",
    "\n",
    "pl.subplot(2,3,2)\n",
    "pl.imshow(Gs,interpolation='nearest')\n",
    "pl.title('OT matrix Sinkhorn')\n",
    "\n",
    "pl.subplot(2,3,3)\n",
    "pl.imshow(Gg,interpolation='nearest')\n",
    "pl.title('OT matrix Group lasso')\n",
    "\n",
    "pl.subplot(2,3,4)\n",
    "ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\n",
    "pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n",
    "\n",
    "\n",
    "pl.subplot(2,3,5)\n",
    "ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])\n",
    "pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n",
    "\n",
    "pl.subplot(2,3,6)\n",
    "ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])\n",
    "pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n",
    "\n",
    "#%% sample interpolation\n",
    "\n",
    "# barycentric mapping: each row of a plan sums to the source weight 1/n\n",
    "# (approximately for the regularized plans), so n*G.dot(xt) is a weighted\n",
    "# average of the target samples for every source sample\n",
    "xst0=n*G0.dot(xt)\n",
    "xsts=n*Gs.dot(xt)\n",
    "xstg=n*Gg.dot(xt)\n",
    "\n",
    "pl.figure(4)\n",
    "pl.subplot(2,3,1)\n",
    "\n",
    "\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\n",
    "pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)\n",
    "pl.title('Interp samples')\n",
    "pl.legend(loc=0)\n",
    "\n",
    "pl.subplot(2,3,2)\n",
    "\n",
    "\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\n",
    "pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)\n",
    "pl.title('Interp samples Sinkhorn')\n",
    "\n",
    "pl.subplot(2,3,3)\n",
    "\n",
    "pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\n",
    "pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)\n",
    "pl.title('Interp samples Grouplasso')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}