diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 10:51:27 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 10:51:27 +0200 |
commit | 581c6de782dca279edd97778cc474e7597788c0f (patch) | |
tree | 760161e1c7812d8caf77bf8acc543453c6213e39 | |
parent | 2109443f5bea396114d1f9e0563ba5c396378c57 (diff) |
demo+sinkhorn
-rw-r--r-- | examples/demo_OT_1D.py | 102 | ||||
-rw-r--r-- | ot/__init__.py | 6 | ||||
-rw-r--r-- | ot/bregman/__init__.py | 4 | ||||
-rw-r--r-- | ot/bregman/sinkhorn.py | 91 | ||||
-rw-r--r-- | ot/datasets.py | 10 | ||||
-rw-r--r-- | ot/emd/emd.pyx | 16 | ||||
-rw-r--r-- | ot/utils.py | 15 | ||||
-rwxr-xr-x | setup.py | 4 |
8 files changed, 246 insertions, 2 deletions
# -*- coding: utf-8 -*-
# examples/demo_OT_1D.py -- new file (mode 100644, index 0000000..29f2074)
"""
1D optimal transport demo: builds two discretised Gaussian histograms,
displays them together with the loss matrix, then solves the exact OT
problem with ``ot.emd`` and displays the resulting transport matrix.

Created on Fri Oct 21 09:51:45 2016

@author: rflamary
"""

import numpy as np
import matplotlib.pylab as pl
from matplotlib import gridspec
import ot


def _plot_marginals_and_matrix(fig_id, x, a, b, mat):
    """Draw `mat` with the target histogram on top and the source on the left.

    Parameters
    ----------
    fig_id : int
        Number of the matplotlib figure to draw into.
    x : ndarray
        Bin positions shared by both histograms.
    a, b : ndarray
        Source and target histograms.
    mat : 2D ndarray
        Matrix shown in the central panel (loss or transport matrix).
    """
    pl.figure(fig_id)
    gs = gridspec.GridSpec(3, 3)

    # top panel: target distribution
    ax_top = pl.subplot(gs[0, 1:])
    pl.plot(x, b, 'r', label='Target distribution')
    pl.yticks(())

    # left panel: source distribution, flipped so that it faces the matrix
    ax_left = pl.subplot(gs[1:, 0])
    pl.plot(a, x, 'b', label='Source distribution')
    pl.gca().invert_xaxis()
    pl.gca().invert_yaxis()
    pl.xticks(())

    # central panel: the matrix itself, sharing its axes with both marginals
    pl.subplot(gs[1:, 1:], sharex=ax_top, sharey=ax_left)
    pl.imshow(mat, interpolation='nearest')
    pl.xlim((0, x.size))


#%% parameters

n = 100  # nb bins

ma = 20  # mean of a
mb = 60  # mean of b

sa = 20  # std of a
sb = 60  # std of b

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions.
# NOTE: the exponent needs sigma squared, i.e. `**`. The previous `2*sa^2`
# was parsed as `(2*sa) ^ 2` (bitwise XOR), which gave 42 instead of 800.
a = np.exp(-(x - ma) ** 2 / (2 * sa ** 2))
b = np.exp(-(x - mb) ** 2 / (2 * sb ** 2))

# normalization
a /= a.sum()
b /= b.sum()

# loss matrix (ot.dist defaults to the squared Euclidean metric),
# rescaled so that its largest entry is 1
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()

#%% plot the distributions

pl.figure(1)

pl.plot(x, a, 'b', label='Source distribution')
pl.plot(x, b, 'r', label='Target distribution')

pl.legend()

#%% plot distributions and loss matrix

_plot_marginals_and_matrix(2, x, a, b, M)

#%% EMD

G0 = ot.emd(a, b, M)

#%% plot EMD optimal transport matrix

_plot_marginals_and_matrix(3, x, a, b, G0)
# ---------------------------------------------------------------------------
# ot/__init__.py (index d0ab3f7..beeae7f), package imports after this commit:
#
#     from emd import emd
#     from bregman import sinkhorn
#
#     import utils
#     import datasets
#     from utils import dist,dots
#
#     __all__ = ["emd"]
#
# ot/bregman/__init__.py (new file, index 0000000..e4016ea):
#
#     from .sinkhorn import sinkhorn
#
# NOTE(review): the committed line reads `from .sink import sinkhorn`, but the
# only module this commit adds to ot/bregman/ is sinkhorn.py, so the import
# has to target `.sinkhorn`.
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# ot/bregman/sinkhorn.py (new file, index 0000000..798ac97)
# Created on Fri Oct 21 09:40:21 2016
# @author: rflamary
# ---------------------------------------------------------------------------

import warnings

import numpy as np


def sinkhorn(a, b, M, reg, numItermax=1000, stopThr=1e-9):
    r"""
    Solve the entropic regularization optimal transport problem (OT)

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the metric cost matrix
    - Omega is the entropic regularization term
    - a and b are the sample weights

    The problem is solved by alternately rescaling the rows and columns of
    the kernel ``K = exp(-M/reg)`` until the column marginal matches `b`.

    Parameters
    ----------
    a : (ns,) ndarray
        samples weights in the source domain
    b : (nt,) ndarray
        samples weights in the target domain
    M : (ns,nt) ndarray
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Maximum number of scaling iterations
    stopThr : float, optional
        Stop threshold on the squared Euclidean error between the column
        marginal of the current plan and `b` (evaluated every 10 iterations)

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters

    Raises
    ------
    ValueError
        If `reg` is not strictly positive.

    Warns
    -----
    RuntimeWarning
        If the iterations hit the numerical precision limit; the last valid
        scalings are then used to build the returned matrix.
    """
    if reg <= 0:
        raise ValueError('reg must be strictly positive')

    # work in floating point whatever the caller hands over (lists, ints...)
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    Nini = len(a)
    Nfin = len(b)

    # scaling vectors; (u, v) always hold the last *valid* iterate
    u = np.ones(Nini) / Nini
    v = np.ones(Nfin) / Nfin

    # Kernel of the problem stated in the docstring: reg weights Omega, hence
    # the exponent is -M/reg (exp(-reg*M) would make reg an inverse weight).
    K = np.exp(-M / reg)

    # rows of K scaled by 1/a, i.e. diag(1/a).K without building the diagonal
    Kp = (1.0 / a)[:, np.newaxis] * K

    cpt = 0
    err = 1.0
    while err > stopThr and cpt < numItermax:
        KtransposeU = np.dot(K.T, u)
        if np.any(KtransposeU == 0):
            # we have reached the machine precision (the kernel underflowed):
            # keep the previous solution and quit the loop
            warnings.warn('numerical errors at iteration %d: K.T u vanished, '
                          'returning the last valid solution' % cpt,
                          RuntimeWarning)
            break

        # invalid values are detected explicitly right below
        with np.errstate(divide='ignore', invalid='ignore'):
            v_new = b / KtransposeU
            u_new = 1.0 / np.dot(Kp, v_new)

        if not (np.all(np.isfinite(u_new)) and np.all(np.isfinite(v_new))):
            # do not commit a broken update: (u, v) still hold the previous
            # valid scalings, which are used for the returned plan
            warnings.warn('numerical errors at iteration %d: non-finite '
                          'scaling, returning the last valid solution' % cpt,
                          RuntimeWarning)
            break

        u, v = u_new, v_new

        if cpt % 10 == 0:
            # we can speed up the process by checking the error only every
            # 10th iteration; the column marginal of diag(u).K.diag(v) is
            # v * (K.T u), so the full matrix is not needed here
            err = np.linalg.norm(v * np.dot(K.T, u) - b) ** 2
        cpt += 1

    # diag(u).K.diag(v) through broadcasting
    return u[:, np.newaxis] * (K * v[np.newaxis, :])


# ---------------------------------------------------------------------------
# ot/datasets.py (new file, index 0000000..bb10ba4)
# ---------------------------------------------------------------------------


def get_1D_gauss(n, m, s):
    """Return a 1D histogram for a gaussian distribution.

    Parameters
    ----------
    n : int
        number of bins (bin positions are 0, 1, ..., n-1)
    m : float
        mean of the gaussian
    s : float
        standard deviation of the gaussian

    Returns
    -------
    h : (n,) ndarray
        values of the gaussian at the bin positions, normalized to sum to 1
    """
    x = np.arange(n, dtype=np.float64)
    # `**` and not `^`: the latter is a bitwise XOR (and fails on floats)
    h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
    return h / h.sum()
# ---------------------------------------------------------------------------
# ot/emd/emd.pyx (index d090cea..e5ac8e0), hunk @@ -19,6 +19,22 @@
# Cython source, only partially visible here; the commit adds a docstring to
# emd(). The hunk is reproduced as-is:
#
#     cdef extern from "EMD.h":
#     @cython.boundscheck(False)
#     @cython.wraparound(False)
#     def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M):
#         """
#         Solves the Earth Movers distance problem and returns the optimal transport matrix
#
#         :param a: m weights of the source distribution (must sum to one)
#         :param b: n weights of the target distribution (must sum to one)
#         :param M: m x n cost matrix
#         :type a: np.ndarray
#         :type b: np.ndarray
#         :type M: np.ndarray
#         :return: Optimal transport matrix
#         :rtype: np.ndarray
#         """
#         cdef int n1= M.shape[0]
#         cdef int n2= M.shape[1]
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# ot/utils.py (new file, index 0000000..1a1c6b8)
# ---------------------------------------------------------------------------

# `reduce` is only a builtin on Python 2; functools provides it on 2.6+ and 3
from functools import reduce

import numpy as np
from scipy.spatial.distance import cdist, pdist


def dist(x1, x2=None, metric='sqeuclidean'):
    """Compute distance between samples in x1 and x2.

    Parameters
    ----------
    x1 : (n1, d) ndarray
        first set of samples, one sample per row
    x2 : (n2, d) ndarray, optional
        second set of samples; when omitted, the pairwise distances inside
        `x1` are computed instead
    metric : str, optional
        metric name forwarded to scipy (squared Euclidean by default)

    Returns
    -------
    ndarray
        The (n1, n2) matrix returned by ``cdist`` when `x2` is given,
        otherwise the condensed distance vector returned by ``pdist``.
    """
    if x2 is None:
        return pdist(x1, metric=metric)
    return cdist(x1, x2, metric=metric)


def dots(*args):
    """Multiply all the given matrices together, from left to right.

    ``dots(A, B, C)`` is ``np.dot(np.dot(A, B), C)``. A single argument is
    returned unchanged; calling it without arguments raises ``TypeError``.
    """
    return reduce(np.dot, args)
\ No newline at end of file @@ -2,11 +2,11 @@ from distutils.core import setup, Extension import numpy -from Cython.Distutils import build_ext +#from Cython.Distutils import build_ext from Cython.Build import cythonize import os -import glob +#import glob version=0.1 |