diff options
author | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2016-11-08 23:15:49 +0100 |
---|---|---|
committer | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2016-11-08 23:15:49 +0100 |
commit | 22036bd9db2cd5cc8329b27ca740ff4d9c114fb7 (patch) | |
tree | b964586e8069b7197652febcafb6ab2918c0d33a /ot/optim.py | |
parent | 0d1f3eb3c41c0b06edc70647037b6cda581e8e2d (diff) |
da with GL
Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/ot/optim.py b/ot/optim.py index 598e23f..d807824 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -7,6 +7,7 @@ import numpy as np from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd from .bregman import sinkhorn_stabilized +from .bregman import sinkhorn # The corresponding scipy function does not work for matrices def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99): @@ -195,7 +196,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa else: return G -def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False): +def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopThr=1e-9,verbose=False,log=False): """ Solve the general regularized OT problem with the generalized conditional gradient @@ -235,6 +236,8 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False initial guess (default is indep joint density) numItermax : int, optional Max number of iterations + numInnerItermax : int, optional + Max number of iterations of Sinkhorn stopThr : float, optional Stop threshol on error (>0) verbose : bool, optional @@ -293,16 +296,16 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False # problem linearization Mi=M+reg2*df(G) - # set M positive - Mi+=Mi.min() # solve linear program with Sinkhorn - Gc = sinkhorn_stabilized(a,b, Mi, reg1) + #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax) + Gc = sinkhorn(a,b, Mi, reg1, numItermax = numInnerItermax) deltaG=Gc-G # line search - alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val) + dcost=Mi+reg1*np.sum(deltaG*(1+np.log(G))) #?? + alpha,fc,f_val = line_search_armijo(cost,G,deltaG,dcost,f_val) G=G+alpha*deltaG |