summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_conv.ipynb
blob: 7fc4af05e71cbbed769ef715e2c3135881484a3a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
{
  "nbformat_minor": 0, 
  "nbformat": 4, 
  "cells": [
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "%matplotlib inline"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }, 
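    {
      "source": [
        "\n# 2D optimal transport with convolutional Sinkhorn demo\n\nComputes exact (EMD) and entropic (Sinkhorn) transport between two small 2D images, then prototypes a Sinkhorn solver whose kernel is applied as separable 1D convolutions.\n\n@author: rflamary\n\n"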
      ], 
      "cell_type": "markdown", 
      "metadata": {}
    }, 
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used\nimport scipy as sp\nimport scipy.signal as sps\n#%% parameters\n\nn=10 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\nxx,yy=np.meshgrid(x,x)\n\n\nxpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))\n\nM=ot.dist(xpos)\n\n\nI0=((xx-5)**2+(yy-5)**2<3**2)*1.0\nI1=((xx-7)**2+(yy-7)**2<3**2)*1.0\n\nI0/=I0.sum()\nI1/=I1.sum()\n\ni0=I0.ravel()\ni1=I1.ravel()\n\nM=M[i0>0,:][:,i1>0].copy()\ni0=i0[i0>0]\ni1=i1[i1>0]\nItot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)\n\n\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I1)\n\n\n#%% barycenter computation\n\nalpha=0.5 # 0<=alpha<=1\nweights=np.array([1-alpha,alpha])\n\n\ndef conv2(I,k):\n    return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)\n\ndef conv2n(I,k):\n    res=np.zeros_like(I)\n    for i in range(I.shape[2]):\n        res[:,:,i]=conv2(I[:,:,i],k)\n    return res\n\n\ndef get_1Dkernel(reg,thr=1e-16,wmax=1024):\n    w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)\n    x=np.arange(w,dtype=np.float64)\n    return np.exp(-((x-w/2)**2)/reg)\n    \nthr=1e-16\nreg=1e0\n\nk=get_1Dkernel(reg)\npl.figure(2)\npl.plot(k)\n\nI05=conv2(I0,k)\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I05)\n\n#%%\n\nG=ot.emd(i0,i1,M)\nr0=np.sum(M*G)\n\nreg=1e-1\nGs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)\nrs=np.sum(M*Gs)\n\n#%%\n\ndef mylog(u):\n    tmp=np.log(u)\n    tmp[np.isnan(tmp)]=0\n    return tmp\n\ndef sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):\n\n\n    a=np.asarray(a,dtype=np.float64)\n    b=np.asarray(b,dtype=np.float64)\n        \n    \n    if len(b.shape)>2:\n        nbb=b.shape[2]\n        a=a[:,:,np.newaxis]\n    else:\n        nbb=0\n    \n\n    if log:\n        log={'err':[]}\n\n    
# we assume that no distances are null except those of the diagonal of distances\n    if nbb:\n        u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))\n        v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))\n        a0=1.0/(np.prod(b.shape[:2]))\n    else:\n        u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))\n        v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))\n        a0=1.0/(np.prod(b.shape[:2]))\n        \n        \n    k=get_1Dkernel(reg)\n    \n    if nbb:\n        K=lambda I: conv2n(I,k)\n    else:\n        K=lambda I: conv2(I,k)\n\n    cpt = 0\n    err=1\n    while (err>stopThr and cpt<numItermax):\n        uprev = u\n        vprev = v\n        \n        v = np.divide(b, K(u))\n        u = np.divide(a, K(v))\n\n        if (np.any(np.isnan(u)) or np.any(np.isnan(v)) \n            or np.any(np.isinf(u)) or np.any(np.isinf(v))):\n            # we have reached the machine precision\n            # come back to previous solution and quit loop\n            print('Warning: numerical errors at iteration', cpt)\n            u = uprev\n            v = vprev\n            break\n        if cpt%10==0:\n            # we can speed up the process by checking for the error only all the 10th iterations\n\n            err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)\n\n            if log:\n                log['err'].append(err)\n\n            if verbose:\n                if cpt%200 ==0:\n                    print('{:5s}|{:12s}'.format('It.','Err')+'\\n'+'-'*19)\n                print('{:5d}|{:8e}|'.format(cpt,err))\n        cpt = cpt +1\n    if log:\n        log['u']=u\n        log['v']=v\n        \n    if nbb: #return only loss \n        res=np.zeros((nbb))\n        for i in range(nbb):\n            res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)\n        if log:\n            return res,log\n        else:\n            return res        \n        \n    else: # return 
OT matrix\n        res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))\n        if log:\n            \n            return res,log\n        else:\n            return res\n\nreg=1e0\nr,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)\na=I0\nb=I1\nu=log['u']\nv=log['v']\n#%% barycenter interpolation"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }
  ], 
  "metadata": {
    "kernelspec": {
      "display_name": "Python 2", 
      "name": "python2", 
      "language": "python"
    }, 
    "language_info": {
      "mimetype": "text/x-python", 
      "nbconvert_exporter": "python", 
      "name": "python", 
      "file_extension": ".py", 
      "version": "2.7.12", 
      "pygments_lexer": "ipython2", 
      "codemirror_mode": {
        "version": 2, 
        "name": "ipython"
      }
    }
  }
}