summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:16:23 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:16:23 +0200
commitf70aabfcc11f92181e0dc987b341bad8ec030d75 (patch)
tree3f7209a6f8421294fe030cab3fbad49904413e4e /ot/optim.py
parent6484c9ea301fc15ae53b4afe134941909f581ffe (diff)
pep8
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py59
1 file changed, 29 insertions, 30 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 9fce21e..cbfb187 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -71,8 +71,9 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
-def do_linesearch(cost,G,deltaG,Mi,f_val,
- amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None):
+
+def do_linesearch(cost, G, deltaG, Mi, f_val,
+ amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
Parameters
@@ -119,22 +120,22 @@ def do_linesearch(cost,G,deltaG,Mi,f_val,
"""
if amijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
- else: # requires symetric matrices
- dot1=np.dot(C1,deltaG)
- dot12=dot1.dot(C2)
- a=-2*reg*np.sum(dot12*deltaG)
- b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG))
- c=cost(G)
+ else: # requires symetric matrices
+ dot1 = np.dot(C1, deltaG)
+ dot12 = dot1.dot(C2)
+ a = -2 * reg * np.sum(dot12 * deltaG)
+ b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ c = cost(G)
+
+ alpha = solve_1d_linesearch_quad_funct(a, b, c)
+ fc = None
+ f_val = cost(G + alpha * deltaG)
- alpha=solve_1d_linesearch_quad_funct(a,b,c)
- fc=None
- f_val=cost(G+alpha*deltaG)
-
- return alpha,fc,f_val
+ return alpha, fc, f_val
def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
- stopThr=1e-9, verbose=False, log=False,**kwargs):
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the general regularized OT problem with conditional gradient
@@ -240,7 +241,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
deltaG = Gc - G
# line search
- alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs)
+ alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
G = G + alpha * deltaG
@@ -403,11 +404,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
else:
return G
-def solve_1d_linesearch_quad_funct(a,b,c):
+
+def solve_1d_linesearch_quad_funct(a, b, c):
"""
- Solve on 0,1 the following problem:
+ Solve on 0,1 the following problem:
.. math::
- \min f(x)=a*x^{2}+b*x+c
+ \min f(x)=a*x^{2}+b*x+c
Parameters
----------
@@ -416,22 +418,19 @@ def solve_1d_linesearch_quad_funct(a,b,c):
Returns
-------
- x : float
+ x : float
The optimal value which leads to the minimal cost
-
+
"""
- f0=c
- df0=b
- f1=a+f0+df0
+ f0 = c
+ df0 = b
+ f1 = a + f0 + df0
- if a>0: # convex
- minimum=min(1,max(0,-b/(2*a)))
- #print('entrelesdeux')
+ if a > 0: # convex
+ minimum = min(1, max(0, -b / (2 * a)))
return minimum
- else: # non convexe donc sur les coins
- if f0>f1:
- #print('sur1 f(1)={}'.format(f(1)))
+ else: # non convex
+ if f0 > f1:
return 1
else:
- #print('sur0 f(0)={}'.format(f(0)))
return 0