summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_conv.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_OT_conv.ipynb')
-rw-r--r--docs/source/auto_examples/plot_OT_conv.ipynb54
1 files changed, 0 insertions, 54 deletions
diff --git a/docs/source/auto_examples/plot_OT_conv.ipynb b/docs/source/auto_examples/plot_OT_conv.ipynb
deleted file mode 100644
index 7fc4af0..0000000
--- a/docs/source/auto_examples/plot_OT_conv.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n# 1D Wasserstein barycenter demo\n\n\n\n@author: rflamary\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used\nimport scipy as sp\nimport scipy.signal as sps\n#%% parameters\n\nn=10 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\nxx,yy=np.meshgrid(x,x)\n\n\nxpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))\n\nM=ot.dist(xpos)\n\n\nI0=((xx-5)**2+(yy-5)**2<3**2)*1.0\nI1=((xx-7)**2+(yy-7)**2<3**2)*1.0\n\nI0/=I0.sum()\nI1/=I1.sum()\n\ni0=I0.ravel()\ni1=I1.ravel()\n\nM=M[i0>0,:][:,i1>0].copy()\ni0=i0[i0>0]\ni1=i1[i1>0]\nItot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)\n\n\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I1)\n\n\n#%% barycenter computation\n\nalpha=0.5 # 0<=alpha<=1\nweights=np.array([1-alpha,alpha])\n\n\ndef conv2(I,k):\n return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)\n\ndef conv2n(I,k):\n res=np.zeros_like(I)\n for i in range(I.shape[2]):\n res[:,:,i]=conv2(I[:,:,i],k)\n return res\n\n\ndef get_1Dkernel(reg,thr=1e-16,wmax=1024):\n w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)\n x=np.arange(w,dtype=np.float64)\n return np.exp(-((x-w/2)**2)/reg)\n \nthr=1e-16\nreg=1e0\n\nk=get_1Dkernel(reg)\npl.figure(2)\npl.plot(k)\n\nI05=conv2(I0,k)\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I05)\n\n#%%\n\nG=ot.emd(i0,i1,M)\nr0=np.sum(M*G)\n\nreg=1e-1\nGs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)\nrs=np.sum(M*Gs)\n\n#%%\n\ndef mylog(u):\n tmp=np.log(u)\n tmp[np.isnan(tmp)]=0\n return tmp\n\ndef sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):\n\n\n a=np.asarray(a,dtype=np.float64)\n b=np.asarray(b,dtype=np.float64)\n \n \n if len(b.shape)>2:\n nbb=b.shape[2]\n a=a[:,:,np.newaxis]\n else:\n nbb=0\n \n\n if log:\n log={'err':[]}\n\n # we assume that no distances are null except those of the diagonal of distances\n if nbb:\n u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))\n v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))\n a0=1.0/(np.prod(b.shape[:2]))\n else:\n u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))\n v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))\n a0=1.0/(np.prod(b.shape[:2]))\n \n \n k=get_1Dkernel(reg)\n \n if nbb:\n K=lambda I: conv2n(I,k)\n else:\n K=lambda I: conv2(I,k)\n\n cpt = 0\n err=1\n while (err>stopThr and cpt<numItermax):\n uprev = u\n vprev = v\n \n v = np.divide(b, K(u))\n u = np.divide(a, K(v))\n\n if (np.any(np.isnan(u)) or np.any(np.isnan(v)) \n or np.any(np.isinf(u)) or np.any(np.isinf(v))):\n # we have reached the machine precision\n # come back to previous solution and quit loop\n print('Warning: numerical errors at iteration', cpt)\n u = uprev\n v = vprev\n break\n if cpt%10==0:\n # we can speed up the process by checking for the error only all the 10th iterations\n\n err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)\n\n if log:\n log['err'].append(err)\n\n if verbose:\n if cpt%200 ==0:\n print('{:5s}|{:12s}'.format('It.','Err')+'\\n'+'-'*19)\n print('{:5d}|{:8e}|'.format(cpt,err))\n cpt = cpt +1\n if log:\n log['u']=u\n log['v']=v\n \n if nbb: #return only loss \n res=np.zeros((nbb))\n for i in range(nbb):\n res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)\n if log:\n return res,log\n else:\n return res \n \n else: # return OT matrix\n res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))\n if log:\n \n return res,log\n else:\n return res\n\nreg=1e0\nr,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)\na=I0\nb=I1\nu=log['u']\nv=log['v']\n#%% barycenter interpolation"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file