summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_conv.py
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_OT_conv.py')
-rw-r--r--docs/source/auto_examples/plot_OT_conv.py200
1 files changed, 200 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_OT_conv.py b/docs/source/auto_examples/plot_OT_conv.py
new file mode 100644
index 0000000..a86e7a2
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_conv.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+1D Wasserstein barycenter demo
+==============================
+
+
+@author: rflamary
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
+import scipy as sp
+import scipy.signal as sps
+#%% parameters
+
+n=10 # nb bins
+
+# bin positions
+x=np.arange(n,dtype=np.float64)
+
+xx,yy=np.meshgrid(x,x)
+
+
+xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))
+
+M=ot.dist(xpos)
+
+
+I0=((xx-5)**2+(yy-5)**2<3**2)*1.0
+I1=((xx-7)**2+(yy-7)**2<3**2)*1.0
+
+I0/=I0.sum()
+I1/=I1.sum()
+
+i0=I0.ravel()
+i1=I1.ravel()
+
+M=M[i0>0,:][:,i1>0].copy()
+i0=i0[i0>0]
+i1=i1[i1>0]
+Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)
+
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.subplot(2,2,1)
+pl.imshow(I0)
+pl.subplot(2,2,2)
+pl.imshow(I1)
+
+
+#%% barycenter computation
+
+alpha=0.5 # 0<=alpha<=1
+weights=np.array([1-alpha,alpha])
+
+
+def conv2(I,k):
+ return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)
+
+def conv2n(I,k):
+ res=np.zeros_like(I)
+ for i in range(I.shape[2]):
+ res[:,:,i]=conv2(I[:,:,i],k)
+ return res
+
+
+def get_1Dkernel(reg,thr=1e-16,wmax=1024):
+ w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)
+ x=np.arange(w,dtype=np.float64)
+ return np.exp(-((x-w/2)**2)/reg)
+
+thr=1e-16
+reg=1e0
+
+k=get_1Dkernel(reg)
+pl.figure(2)
+pl.plot(k)
+
+I05=conv2(I0,k)
+
+pl.figure(1)
+pl.subplot(2,2,1)
+pl.imshow(I0)
+pl.subplot(2,2,2)
+pl.imshow(I05)
+
+#%%
+
+G=ot.emd(i0,i1,M)
+r0=np.sum(M*G)
+
+reg=1e-1
+Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)
+rs=np.sum(M*Gs)
+
+#%%
+
+def mylog(u):
+ tmp=np.log(u)
+ tmp[np.isnan(tmp)]=0
+ return tmp
+
+def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
+
+
+ a=np.asarray(a,dtype=np.float64)
+ b=np.asarray(b,dtype=np.float64)
+
+
+ if len(b.shape)>2:
+ nbb=b.shape[2]
+ a=a[:,:,np.newaxis]
+ else:
+ nbb=0
+
+
+ if log:
+ log={'err':[]}
+
+ # we assume that no distances are null except those of the diagonal of distances
+ if nbb:
+ u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))
+ v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))
+ a0=1.0/(np.prod(b.shape[:2]))
+ else:
+ u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))
+ v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))
+ a0=1.0/(np.prod(b.shape[:2]))
+
+
+ k=get_1Dkernel(reg)
+
+ if nbb:
+ K=lambda I: conv2n(I,k)
+ else:
+ K=lambda I: conv2(I,k)
+
+ cpt = 0
+ err=1
+ while (err>stopThr and cpt<numItermax):
+ uprev = u
+ vprev = v
+
+ v = np.divide(b, K(u))
+ u = np.divide(a, K(v))
+
+ if (np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ print('Warning: numerical errors at iteration', cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt%10==0:
+ # we can speed up the process by checking for the error only all the 10th iterations
+
+ err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%200 ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+ cpt = cpt +1
+ if log:
+ log['u']=u
+ log['v']=v
+
+ if nbb: #return only loss
+ res=np.zeros((nbb))
+ for i in range(nbb):
+ res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
+ if log:
+ return res,log
+ else:
+ return res
+
+ else: # return OT matrix
+ res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))
+ if log:
+
+ return res,log
+ else:
+ return res
+
+reg=1e0
+r,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)
+a=I0
+b=I1
+u=log['u']
+v=log['v']
+#%% barycenter interpolation