diff options
-rw-r--r-- | ot/__init__.py | 2 | ||||
-rw-r--r-- | ot/bregman/__init__.py | 4 | ||||
-rw-r--r-- | ot/bregman/sink.py | 91 | ||||
-rwxr-xr-x | setup.py | 2 |
4 files changed, 2 insertions, 97 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 0a9e89b..f5490f7 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -14,4 +14,4 @@ from bregman import sinkhorn # utils functions from utils import dist,dots,unif -__all__ = ["emd","sinkhorn","utils",'datasets','plot','dist','dots'] +__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','dots'] diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py deleted file mode 100644 index e4016ea..0000000 --- a/ot/bregman/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ - - -from .sink import sinkhorn - diff --git a/ot/bregman/sink.py b/ot/bregman/sink.py deleted file mode 100644 index 8b97e1e..0000000 --- a/ot/bregman/sink.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Fri Oct 21 09:40:21 2016 - -@author: rflamary -""" - -import numpy as np - - -def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): - """ - Solve the optimal transport problem (OT) - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the metric cost matrix - - Omega is the entropic regularization term - - a and b are the sample weights - - Parameters - ---------- - a : (ns,) ndarray - samples in the source domain - b : (nt,) ndarray - samples in the target domain - M : (ns,nt) ndarray - loss matrix - reg: float() - Regularization term >0 - - - Returns - ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - - """ - # init data - Nini = len(a) - Nfin = len(b) - - - cpt = 0 - - # we assume that no distances are null except those of the diagonal of distances - u = np.ones(Nini)/Nini - v = np.ones(Nfin)/Nfin - uprev=np.zeros(Nini) - vprev=np.zeros(Nini) - - #print reg - - K = np.exp(-M/reg) - #print np.min(K) - - Kp = np.dot(np.diag(1/a),K) - transp = K - cpt = 0 - err=1 - while (err>stopThr and cpt<numItermax): - if np.any(np.dot(K.T,u)==0) or np.any(np.isnan(u)) or np.any(np.isnan(v)): - # we have reached the machine precision - # come back to previous solution and quit loop - print('Warning: numerical errrors') - if cpt!=0: - u = uprev - v = vprev - break - uprev = u - vprev = v - v = np.divide(b,np.dot(K.T,u)) - u = 1./np.dot(Kp,v) - if cpt%10==0: - # we can speed up the process by checking for the error only all the 10th iterations - transp = np.dot(np.diag(u),np.dot(K,np.diag(v))) - err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 - cpt = cpt +1 - #print 'err=',err,' cpt=',cpt - - return np.dot(np.diag(u),np.dot(K,np.diag(v))) - - @@ -20,7 +20,7 @@ setup(name='python-webgen', author=u'Remi Flamary', author_email='remi.flamary@gmail.com', url='https://github.com/rflamary/POT', - packages=['ot','ot.emd','ot.bregman'], + packages=['ot','ot.emd'], ext_modules = cythonize(Extension( "ot.emd.emd", # the extesion name sources=["ot/emd/emd.pyx", "ot/emd/EMD_wrap.cpp"], # the Cython source and |