summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2016-11-08 23:15:49 +0100
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2016-11-08 23:15:49 +0100
commit22036bd9db2cd5cc8329b27ca740ff4d9c114fb7 (patch)
treeb964586e8069b7197652febcafb6ab2918c0d33a /ot/optim.py
parent0d1f3eb3c41c0b06edc70647037b6cda581e8e2d (diff)
da with GL
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 598e23f..d807824 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -7,6 +7,7 @@ import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn_stabilized
+from .bregman import sinkhorn
# The corresponding scipy function does not work for matrices
def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
@@ -195,7 +196,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
else:
return G
-def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False):
+def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopThr=1e-9,verbose=False,log=False):
"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -235,6 +236,8 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations of Sinkhorn
stopThr : float, optional
Stop threshol on error (>0)
verbose : bool, optional
@@ -293,16 +296,16 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False
# problem linearization
Mi=M+reg2*df(G)
- # set M positive
- Mi+=Mi.min()
# solve linear program with Sinkhorn
- Gc = sinkhorn_stabilized(a,b, Mi, reg1)
+ #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
+ Gc = sinkhorn(a,b, Mi, reg1, numItermax = numInnerItermax)
deltaG=Gc-G
# line search
- alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val)
+ dcost=Mi+reg1*np.sum(deltaG*(1+np.log(G))) #??
+ alpha,fc,f_val = line_search_armijo(cost,G,deltaG,dcost,f_val)
G=G+alpha*deltaG