# -*- coding: utf-8 -*- """ ============================== 1D Wasserstein barycenter demo ============================== @author: rflamary """ import numpy as np import matplotlib.pylab as pl import ot from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used import scipy as sp import scipy.signal as sps #%% parameters n=10 # nb bins # bin positions x=np.arange(n,dtype=np.float64) xx,yy=np.meshgrid(x,x) xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1))) M=ot.dist(xpos) I0=((xx-5)**2+(yy-5)**2<3**2)*1.0 I1=((xx-7)**2+(yy-7)**2<3**2)*1.0 I0/=I0.sum() I1/=I1.sum() i0=I0.ravel() i1=I1.ravel() M=M[i0>0,:][:,i1>0].copy() i0=i0[i0>0] i1=i1[i1>0] Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2) #%% plot the distributions pl.figure(1) pl.subplot(2,2,1) pl.imshow(I0) pl.subplot(2,2,2) pl.imshow(I1) #%% barycenter computation alpha=0.5 # 0<=alpha<=1 weights=np.array([1-alpha,alpha]) def conv2(I,k): return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0) def conv2n(I,k): res=np.zeros_like(I) for i in range(I.shape[2]): res[:,:,i]=conv2(I[:,:,i],k) return res def get_1Dkernel(reg,thr=1e-16,wmax=1024): w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3) x=np.arange(w,dtype=np.float64) return np.exp(-((x-w/2)**2)/reg) thr=1e-16 reg=1e0 k=get_1Dkernel(reg) pl.figure(2) pl.plot(k) I05=conv2(I0,k) pl.figure(1) pl.subplot(2,2,1) pl.imshow(I0) pl.subplot(2,2,2) pl.imshow(I05) #%% G=ot.emd(i0,i1,M) r0=np.sum(M*G) reg=1e-1 Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg) rs=np.sum(M*Gs) #%% def mylog(u): tmp=np.log(u) tmp[np.isnan(tmp)]=0 return tmp def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs): a=np.asarray(a,dtype=np.float64) b=np.asarray(b,dtype=np.float64) if len(b.shape)>2: nbb=b.shape[2] a=a[:,:,np.newaxis] else: nbb=0 if log: log={'err':[]} # we assume that no distances are null except those of the diagonal of distances if nbb: u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2])) v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2])) a0=1.0/(np.prod(b.shape[:2])) else: u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2])) v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2])) a0=1.0/(np.prod(b.shape[:2])) k=get_1Dkernel(reg) if nbb: K=lambda I: conv2n(I,k) else: K=lambda I: conv2(I,k) cpt = 0 err=1 while (err>stopThr and cpt