diff options
Diffstat (limited to 'docs/source/auto_examples/plot_OT_conv.py')
-rw-r--r-- | docs/source/auto_examples/plot_OT_conv.py | 200 |
1 files changed, 0 insertions, 200 deletions
diff --git a/docs/source/auto_examples/plot_OT_conv.py b/docs/source/auto_examples/plot_OT_conv.py deleted file mode 100644 index a86e7a2..0000000 --- a/docs/source/auto_examples/plot_OT_conv.py +++ /dev/null @@ -1,200 +0,0 @@ -# -*- coding: utf-8 -*- -""" -============================== -1D Wasserstein barycenter demo -============================== - - -@author: rflamary -""" - -import numpy as np -import matplotlib.pylab as pl -import ot -from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used -import scipy as sp -import scipy.signal as sps -#%% parameters - -n=10 # nb bins - -# bin positions -x=np.arange(n,dtype=np.float64) - -xx,yy=np.meshgrid(x,x) - - -xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1))) - -M=ot.dist(xpos) - - -I0=((xx-5)**2+(yy-5)**2<3**2)*1.0 -I1=((xx-7)**2+(yy-7)**2<3**2)*1.0 - -I0/=I0.sum() -I1/=I1.sum() - -i0=I0.ravel() -i1=I1.ravel() - -M=M[i0>0,:][:,i1>0].copy() -i0=i0[i0>0] -i1=i1[i1>0] -Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2) - - -#%% plot the distributions - -pl.figure(1) -pl.subplot(2,2,1) -pl.imshow(I0) -pl.subplot(2,2,2) -pl.imshow(I1) - - -#%% barycenter computation - -alpha=0.5 # 0<=alpha<=1 -weights=np.array([1-alpha,alpha]) - - -def conv2(I,k): - return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0) - -def conv2n(I,k): - res=np.zeros_like(I) - for i in range(I.shape[2]): - res[:,:,i]=conv2(I[:,:,i],k) - return res - - -def get_1Dkernel(reg,thr=1e-16,wmax=1024): - w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3) - x=np.arange(w,dtype=np.float64) - return np.exp(-((x-w/2)**2)/reg) - -thr=1e-16 -reg=1e0 - -k=get_1Dkernel(reg) -pl.figure(2) -pl.plot(k) - -I05=conv2(I0,k) - -pl.figure(1) -pl.subplot(2,2,1) -pl.imshow(I0) -pl.subplot(2,2,2) -pl.imshow(I05) - -#%% - -G=ot.emd(i0,i1,M) -r0=np.sum(M*G) - -reg=1e-1 -Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg) -rs=np.sum(M*Gs) - -#%% - -def mylog(u): - tmp=np.log(u) - tmp[np.isnan(tmp)]=0 - return tmp - -def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): - - - a=np.asarray(a,dtype=np.float64) - b=np.asarray(b,dtype=np.float64) - - - if len(b.shape)>2: - nbb=b.shape[2] - a=a[:,:,np.newaxis] - else: - nbb=0 - - - if log: - log={'err':[]} - - # we assume that no distances are null except those of the diagonal of distances - if nbb: - u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2])) - v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2])) - a0=1.0/(np.prod(b.shape[:2])) - else: - u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2])) - v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2])) - a0=1.0/(np.prod(b.shape[:2])) - - - k=get_1Dkernel(reg) - - if nbb: - K=lambda I: conv2n(I,k) - else: - K=lambda I: conv2(I,k) - - cpt = 0 - err=1 - while (err>stopThr and cpt<numItermax): - uprev = u - vprev = v - - v = np.divide(b, K(u)) - u = np.divide(a, K(v)) - - if (np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) - u = uprev - v = vprev - break - if cpt%10==0: - # we can speed up the process by checking for the error only all the 10th iterations - - err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2) - - if log: - log['err'].append(err) - - if verbose: - if cpt%200 ==0: - print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) - print('{:5d}|{:8e}|'.format(cpt,err)) - cpt = cpt +1 - if log: - log['u']=u - log['v']=v - - if nbb: #return only loss - res=np.zeros((nbb)) - for i in range(nbb): - res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M) - if log: - return res,log - else: - return res - - else: # return OT matrix - res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0))) - if log: - - return res,log - else: - return res - -reg=1e0 -r,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True) -a=I0 -b=I1 -u=log['u'] -v=log['v'] -#%% barycenter interpolation |