diff options
-rw-r--r-- | ot/__init__.py | 7 | ||||
-rw-r--r-- | ot/bregman.py | 27 | ||||
-rw-r--r-- | ot/lp/emd.cpp | 2 | ||||
-rw-r--r-- | ot/optim.py | 58 |
4 files changed, 78 insertions, 16 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 87119e5..863f408 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,13 +1,14 @@ # Python Optimal Transport toolbox # All submodules and packages +from . import lp +from . import bregman +from . import optim from . import utils from . import datasets from . import plot -from . import bregman -from . import lp from . import da -from . import optim + # OT functions diff --git a/ot/bregman.py b/ot/bregman.py index b749b13..08f965b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ -Bregman projection for regularized Otimal transport +Bregman projections for regularized OT """ import numpy as np -def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): +def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False): """ Solve the entropic regularization optimal transport problem and return the OT matrix @@ -43,14 +43,18 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): Max number of iterations stopThr: float, optional Stop threshol on error (>0) - + verbose : int, optional + Print information along iterations + log : int, optional + record log if True Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters - + log: dict + log dictionary return only if log==True in parameters Examples -------- @@ -91,6 +95,8 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): cpt = 0 + if log: + log={'loss':[]} # we assume that no distances are null except those of the diagonal of distances u = np.ones(Nini)/Nini @@ -124,10 +130,19 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): # we can speed up the process by checking for the error only all the 10th iterations transp = np.dot(np.diag(u),np.dot(K,np.diag(v))) err = np.linalg.norm((np.sum(transp,axis=0)-b))**2 + if log: + log['loss'].append(err) + + if verbose: + if cpt%200 ==0: + print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19) + 
print('{:5d}|{:8e}|'.format(cpt,err)) cpt = cpt +1 #print 'err=',err,' cpt=',cpt - - return np.dot(np.diag(u),np.dot(K,np.diag(v))) + if log: + return np.dot(np.diag(u),np.dot(K,np.diag(v))),log + else: + return np.dot(np.diag(u),np.dot(K,np.diag(v))) def geometricBar(weights,alldistribT): diff --git a/ot/lp/emd.cpp b/ot/lp/emd.cpp index 26d243f..6db54bb 100644 --- a/ot/lp/emd.cpp +++ b/ot/lp/emd.cpp @@ -1229,7 +1229,7 @@ static PyObject *__pyx_codeobj__8; /* Python wrapper */ static PyObject *__pyx_pw_2ot_2lp_3emd_1emd_c(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ -static char __pyx_doc_2ot_2lp_3emd_emd_c[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n gamm=emd(a,b,M)\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray\n source histogram \n b : (nt,) ndarray\n target histogram\n M : (ns,nt) ndarray\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n "; +static char __pyx_doc_2ot_2lp_3emd_emd_c[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n gamm=emd(a,b,M)\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. 
\\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray, float64\n source histogram \n b : (nt,) ndarray, float64\n target histogram\n M : (ns,nt) ndarray, float64\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n "; static PyMethodDef __pyx_mdef_2ot_2lp_3emd_1emd_c = {"emd_c", (PyCFunction)__pyx_pw_2ot_2lp_3emd_1emd_c, METH_VARARGS|METH_KEYWORDS, __pyx_doc_2ot_2lp_3emd_emd_c}; static PyObject *__pyx_pw_2ot_2lp_3emd_1emd_c(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { PyArrayObject *__pyx_v_a = 0; diff --git a/ot/optim.py b/ot/optim.py index e6373ce..d1bf672 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- """ -Created on Wed Oct 26 15:08:19 2016 - -@author: rflamary +Optimization algorithms for OT """ import numpy as np @@ -12,6 +10,42 @@ from lp import emd # The corresponding scipy function does not work for matrices def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): + """ + Armijo linesearch function that works with matrices + + find an approximate minimum of f(xk+alpha*pk) that satifies the + armijo conditions. 
+ + Parameters + ---------- + + f : function + loss function + xk : np.ndarray + initial position + pk : np.ndarray + descent direction + gfk : np.ndarray + gradient of f at xk + old_fval: float + loss value at xk + args : tuple, optional + arguments given to f + c1 : float, optional + c1 const in armijo rule (>0) + alpha0 : float, optional + initial step (>0) + + Returns + ------- + alpha : float + step that satisfies armijo conditions + fc : int + number of function calls + fa : float + loss value at step alpha + + """ xk = np.atleast_1d(xk) fc = [0] @@ -61,14 +95,26 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa samples in the target domain M : np.ndarray (ns,nt) loss matrix - reg: float() + reg : float Regularization term >0 - + G0 : np.ndarray (ns,nt), optional + initial guess (default is indep joint density) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : int, optional + Print information along iterations + log : int, optional + record log if True Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + log: dict + log dictionary returned only if log==True in parameters + References ---------- @@ -77,7 +123,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa See Also -------- - ot.emd.emd : Unregularized optimal ransport + ot.lp.emd : Unregularized optimal transport ot.bregman.sinkhorn : Entropic regularized optimal transport """ |