.. _sphx_glr_auto_examples_plot_OT_conv.py: ============================== 1D Wasserstein barycenter demo ============================== @author: rflamary .. code-block:: pytb Traceback (most recent call last): File "/home/rflamary/.local/lib/python2.7/site-packages/sphinx_gallery/gen_rst.py", line 518, in execute_code_block exec(code_block, example_globals) File "", line 86, in TypeError: unsupported operand type(s) for *: 'float' and 'Mock' .. code-block:: python import numpy as np import matplotlib.pylab as pl import ot from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used import scipy as sp import scipy.signal as sps #%% parameters n=10 # nb bins # bin positions x=np.arange(n,dtype=np.float64) xx,yy=np.meshgrid(x,x) xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1))) M=ot.dist(xpos) I0=((xx-5)**2+(yy-5)**2<3**2)*1.0 I1=((xx-7)**2+(yy-7)**2<3**2)*1.0 I0/=I0.sum() I1/=I1.sum() i0=I0.ravel() i1=I1.ravel() M=M[i0>0,:][:,i1>0].copy() i0=i0[i0>0] i1=i1[i1>0] Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2) #%% plot the distributions pl.figure(1) pl.subplot(2,2,1) pl.imshow(I0) pl.subplot(2,2,2) pl.imshow(I1) #%% barycenter computation alpha=0.5 # 0<=alpha<=1 weights=np.array([1-alpha,alpha]) def conv2(I,k): return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0) def conv2n(I,k): res=np.zeros_like(I) for i in range(I.shape[2]): res[:,:,i]=conv2(I[:,:,i],k) return res def get_1Dkernel(reg,thr=1e-16,wmax=1024): w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3) x=np.arange(w,dtype=np.float64) return np.exp(-((x-w/2)**2)/reg) thr=1e-16 reg=1e0 k=get_1Dkernel(reg) pl.figure(2) pl.plot(k) I05=conv2(I0,k) pl.figure(1) pl.subplot(2,2,1) pl.imshow(I0) pl.subplot(2,2,2) pl.imshow(I05) #%% G=ot.emd(i0,i1,M) r0=np.sum(M*G) reg=1e-1 Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg) rs=np.sum(M*Gs) #%% def mylog(u): tmp=np.log(u) tmp[np.isnan(tmp)]=0 return tmp def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): a=np.asarray(a,dtype=np.float64) b=np.asarray(b,dtype=np.float64) if len(b.shape)>2: nbb=b.shape[2] a=a[:,:,np.newaxis] else: nbb=0 if log: log={'err':[]} # we assume that no distances are null except those of the diagonal of distances if nbb: u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2])) v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2])) a0=1.0/(np.prod(b.shape[:2])) else: u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2])) v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2])) a0=1.0/(np.prod(b.shape[:2])) k=get_1Dkernel(reg) if nbb: K=lambda I: conv2n(I,k) else: K=lambda I: conv2(I,k) cpt = 0 err=1 while (err>stopThr and cpt` .. container:: sphx-glr-download :download:`Download Jupyter notebook: plot_OT_conv.ipynb ` .. rst-class:: sphx-glr-signature `Generated by Sphinx-Gallery `_