diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 10:51:27 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-21 10:51:27 +0200 |
commit | 581c6de782dca279edd97778cc474e7597788c0f (patch) | |
tree | 760161e1c7812d8caf77bf8acc543453c6213e39 /ot | |
parent | 2109443f5bea396114d1f9e0563ba5c396378c57 (diff) |
demo+sinkhorn
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 6 | ||||
-rw-r--r-- | ot/bregman/__init__.py | 4 | ||||
-rw-r--r-- | ot/bregman/sinkhorn.py | 91 | ||||
-rw-r--r-- | ot/datasets.py | 10 | ||||
-rw-r--r-- | ot/emd/emd.pyx | 16 | ||||
-rw-r--r-- | ot/utils.py | 15 |
6 files changed, 142 insertions, 0 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index d0ab3f7..beeae7f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,6 +1,12 @@ + from emd import emd +from bregman import sinkhorn + +import utils +import datasets +from utils import dist,dots __all__ = ["emd"] diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py new file mode 100644 index 0000000..e4016ea --- /dev/null +++ b/ot/bregman/__init__.py @@ -0,0 +1,4 @@ + + +from .sink import sinkhorn + diff --git a/ot/bregman/sinkhorn.py b/ot/bregman/sinkhorn.py new file mode 100644 index 0000000..798ac97 --- /dev/null +++ b/ot/bregman/sinkhorn.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Oct 21 09:40:21 2016 + +@author: rflamary +""" + +import numpy as np + + +def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): + """ + Solve the optimal transport problem (OT) + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the metric cost matrix + - Omega is the entropic regularization term + - a and b are the sample weights + + Parameters + ---------- + a : (ns,) ndarray + samples in the source domain + b : (nt,) ndarray + samples in the target domain + M : (ns,nt) ndarray + loss matrix + reg: float() + Regularization term >0 + + + Returns + ------- + gamma: (ns x nt) ndarray + Optimal transportation matrix for the given parameters + + """ + # init data + Nini = len(a) + Nfin = len(b) + + + cpt = 0 + + # we assume that no distances are null except those of the diagonal of distances + u = np.ones(Nini)/Nini + v = np.ones(Nfin)/Nfin + uprev=np.zeros(Nini) + vprev=np.zeros(Nini) + + #print reg + + K = np.exp(-reg*M) + #print np.min(K) + + Kp = np.dot(np.diag(1/a),K) + transp = K + cpt = 0 + err=1 + while (err>stopThr and cpt<numItermax): + if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)): + # we have reached the machine precision + # come back to previous solution and quit loop + print('Warning: numerical errrors') + if cpt!=0: + u = uprev + v = vprev + break + uprev = u + vprev = v + v = np.divide(b,np.dot(K.T,u)) + u = 1./np.dot(Kp,v) + if cpt%10==0: + # we can speed up the process by checking for the error only all the 10th iterations + transp = np.dot(np.diag(u),np.dot(K,np.diag(v))) + err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + cpt = cpt +1 + #print 'err=',err,' cpt=',cpt + + return np.dot(np.diag(u),np.dot(K,np.diag(v))) + + diff --git a/ot/datasets.py b/ot/datasets.py new file mode 100644 index 0000000..bb10ba4 --- /dev/null +++ b/ot/datasets.py @@ -0,0 +1,10 @@ + +import numpy as np + + + +def get_1D_gauss(n,m,s): + "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) " + x=np.arange(n,dtype=np.float64) + h=np.exp(-(x-m)**2/(2*s^2)) + return h/h.sum()
\ No newline at end of file diff --git a/ot/emd/emd.pyx b/ot/emd/emd.pyx index d090cea..e5ac8e0 100644 --- a/ot/emd/emd.pyx +++ b/ot/emd/emd.pyx @@ -19,6 +19,22 @@ cdef extern from "EMD.h": @cython.boundscheck(False) @cython.wraparound(False) def emd( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M): + """ + Solves the Earth Movers distance problem and returns the optimal transport matrix + + + + :param a: m weights of the source distribution (must sum to one) + :param b: n weights of the target distribution (must sum to one) + :param M: m x n cost matrix + :type a: np.ndarray + :type b: np.ndarray + :type M: np.ndarray + :return: Optimal transport matrix + :rtype: np.ndarray + + + """ cdef int n1= M.shape[0] cdef int n2= M.shape[1] diff --git a/ot/utils.py b/ot/utils.py new file mode 100644 index 0000000..1a1c6b8 --- /dev/null +++ b/ot/utils.py @@ -0,0 +1,15 @@ + +import numpy as np +from scipy.spatial.distance import cdist, pdist + + +def dist(x1,x2=None,metric='sqeuclidean'): + """Compute distance between samples in x1 and x2""" + if x2 is None: + return pdist(x1,metric=metric) + else: + return cdist(x1,x2,metric=metric) + +def dots(*args): + """ Stupid but nice dots function for multiple matrix multiply """ + return reduce(np.dot,args)
\ No newline at end of file |