From adbf95ee9a720fa38b5b91d0a9d5c3b22ba0b226 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 3 Jul 2017 16:33:00 +0200 Subject: update doc --- docs/source/auto_examples/plot_OT_conv.py | 200 ++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 docs/source/auto_examples/plot_OT_conv.py (limited to 'docs/source/auto_examples/plot_OT_conv.py') diff --git a/docs/source/auto_examples/plot_OT_conv.py b/docs/source/auto_examples/plot_OT_conv.py new file mode 100644 index 0000000..a86e7a2 --- /dev/null +++ b/docs/source/auto_examples/plot_OT_conv.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- +""" +============================== +1D Wasserstein barycenter demo +============================== + + +@author: rflamary +""" + +import numpy as np +import matplotlib.pylab as pl +import ot +from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used +import scipy as sp +import scipy.signal as sps +#%% parameters + +n=10 # nb bins + +# bin positions +x=np.arange(n,dtype=np.float64) + +xx,yy=np.meshgrid(x,x) + + +xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1))) + +M=ot.dist(xpos) + + +I0=((xx-5)**2+(yy-5)**2<3**2)*1.0 +I1=((xx-7)**2+(yy-7)**2<3**2)*1.0 + +I0/=I0.sum() +I1/=I1.sum() + +i0=I0.ravel() +i1=I1.ravel() + +M=M[i0>0,:][:,i1>0].copy() +i0=i0[i0>0] +i1=i1[i1>0] +Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2) + + +#%% plot the distributions + +pl.figure(1) +pl.subplot(2,2,1) +pl.imshow(I0) +pl.subplot(2,2,2) +pl.imshow(I1) + + +#%% barycenter computation + +alpha=0.5 # 0<=alpha<=1 +weights=np.array([1-alpha,alpha]) + + +def conv2(I,k): + return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0) + +def conv2n(I,k): + res=np.zeros_like(I) + for i in range(I.shape[2]): + res[:,:,i]=conv2(I[:,:,i],k) + return res + + +def get_1Dkernel(reg,thr=1e-16,wmax=1024): + w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3) + x=np.arange(w,dtype=np.float64) + return np.exp(-((x-w/2)**2)/reg) + +thr=1e-16 +reg=1e0 + +k=get_1Dkernel(reg) +pl.figure(2) +pl.plot(k) + +I05=conv2(I0,k) + +pl.figure(1) +pl.subplot(2,2,1) +pl.imshow(I0) +pl.subplot(2,2,2) +pl.imshow(I05) + +#%% + +G=ot.emd(i0,i1,M) +r0=np.sum(M*G) + +reg=1e-1 +Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg) +rs=np.sum(M*Gs) + +#%% + +def mylog(u): + tmp=np.log(u) + tmp[np.isnan(tmp)]=0 + return tmp + +def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): + + + a=np.asarray(a,dtype=np.float64) + b=np.asarray(b,dtype=np.float64) + + + if len(b.shape)>2: + nbb=b.shape[2] + a=a[:,:,np.newaxis] + else: + nbb=0 + + + if log: + log={'err':[]} + + # we assume that no distances are null except those of the diagonal of distances + if nbb: + u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2])) + v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2])) + a0=1.0/(np.prod(b.shape[:2])) + else: + u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2])) + v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2])) + a0=1.0/(np.prod(b.shape[:2])) + + + k=get_1Dkernel(reg) + + if nbb: + K=lambda I: conv2n(I,k) + else: + K=lambda I: conv2(I,k) + + cpt = 0 + err=1 + while (err>stopThr and cpt