Diffstat (limited to 'ot/optim.py')
-rw-r--r--  ot/optim.py  139
1 file changed, 72 insertions, 67 deletions
diff --git a/ot/optim.py b/ot/optim.py
index 79f4f66..1d09adc 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -3,13 +3,19 @@
Optimization algorithms for OT
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn
# The corresponding scipy function does not work for matrices
-def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
+
+
+def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99):
"""
Armijo linesearch function that works with matrices
@@ -51,20 +57,21 @@ def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
def phi(alpha1):
fc[0] += 1
- return f(xk + alpha1*pk, *args)
+ return f(xk + alpha1 * pk, *args)
if old_fval is None:
phi0 = phi(0.)
else:
phi0 = old_fval
- derphi0 = np.sum(pk*gfk) # Quickfix for matrices
- alpha,phi1 = scalar_search_armijo(phi,phi0,derphi0,c1=c1,alpha0=alpha0)
+ derphi0 = np.sum(pk * gfk) # Quickfix for matrices
+ alpha, phi1 = scalar_search_armijo(
+ phi, phi0, derphi0, c1=c1, alpha0=alpha0)
- return alpha,fc[0],phi1
+ return alpha, fc[0], phi1
-def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False):
+def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with conditional gradient
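
The reformatted Armijo search keeps its original signature, so it can still be exercised directly on matrix iterates. A minimal sketch, assuming ot.optim is importable as in the POT package; the quadratic objective and toy matrices below are illustrative only, not part of this change:

import numpy as np
from ot.optim import line_search_armijo

# illustrative quadratic objective over matrices: quad(X) = 0.5 * ||X||_F^2
def quad(X):
    return 0.5 * np.sum(X ** 2)

xk = np.ones((3, 3))     # current iterate (a matrix, not a vector)
gfk = xk.copy()          # gradient of quad at xk
pk = -gfk                # steepest-descent direction

# returns the accepted step size, the number of objective evaluations
# and the new objective value
alpha, fc, phi1 = line_search_armijo(quad, xk, pk, gfk, quad(xk))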
@@ -128,74 +135,74 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
"""
- loop=1
+ loop = 1
if log:
- log={'loss':[]}
+ log = {'loss': []}
if G0 is None:
- G=np.outer(a,b)
+ G = np.outer(a, b)
else:
- G=G0
+ G = G0
def cost(G):
- return np.sum(M*G)+reg*f(G)
+ return np.sum(M * G) + reg * f(G)
- f_val=cost(G)
+ f_val = cost(G)
if log:
log['loss'].append(f_val)
- it=0
+ it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,f_val,0))
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
while loop:
- it+=1
- old_fval=f_val
-
+ it += 1
+ old_fval = f_val
# problem linearization
- Mi=M+reg*df(G)
+ Mi = M + reg * df(G)
# set M positive
- Mi+=Mi.min()
+ Mi += Mi.min()
# solve linear program
- Gc=emd(a,b,Mi)
+ Gc = emd(a, b, Mi)
- deltaG=Gc-G
+ deltaG = Gc - G
# line search
- alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val)
+ alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
- G=G+alpha*deltaG
+ G = G + alpha * deltaG
# test convergence
- if it>=numItermax:
- loop=0
-
- delta_fval=(f_val-old_fval)/abs(f_val)
- if abs(delta_fval)<stopThr:
- loop=0
+ if it >= numItermax:
+ loop = 0
+ delta_fval = (f_val - old_fval) / abs(f_val)
+ if abs(delta_fval) < stopThr:
+ loop = 0
if log:
log['loss'].append(f_val)
if verbose:
- if it%20 ==0:
- print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,f_val,delta_fval))
-
+ if it % 20 == 0:
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
if log:
- return G,log
+ return G, log
else:
return G
-def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopThr=1e-9,verbose=False,log=False):
+
+def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with the generalized conditional gradient
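
The cg hunk above is a whitespace/PEP8 reflow, so the solver keeps its documented interface. A minimal usage sketch, assuming the POT package layout (ot.dist, ot.optim.cg); the uniform toy histograms and the squared-Frobenius regularizer are illustrative assumptions:

import numpy as np
import ot

n = 20
a = np.ones(n) / n                        # uniform source histogram
b = np.ones(n) / n                        # uniform target histogram
x = np.arange(n, dtype=np.float64).reshape(-1, 1)
M = ot.dist(x, x)                         # squared Euclidean cost matrix
M /= M.max()

def f(G):                                 # squared-Frobenius regularizer
    return 0.5 * np.sum(G ** 2)

def df(G):                                # its gradient
    return G

reg = 1e-1
G = ot.optim.cg(a, b, M, reg, f, df, verbose=False)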
@@ -264,70 +271,68 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopT
"""
- loop=1
+ loop = 1
if log:
- log={'loss':[]}
+ log = {'loss': []}
if G0 is None:
- G=np.outer(a,b)
+ G = np.outer(a, b)
else:
- G=G0
+ G = G0
def cost(G):
- return np.sum(M*G)+ reg1*np.sum(G*np.log(G)) + reg2*f(G)
+ return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
- f_val=cost(G)
+ f_val = cost(G)
if log:
log['loss'].append(f_val)
- it=0
+ it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,f_val,0))
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
while loop:
- it+=1
- old_fval=f_val
-
+ it += 1
+ old_fval = f_val
# problem linearization
- Mi=M+reg2*df(G)
+ Mi = M + reg2 * df(G)
# solve linear program with Sinkhorn
- #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
- Gc = sinkhorn(a,b, Mi, reg1, numItermax = numInnerItermax)
+ # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
+ Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
- deltaG=Gc-G
+ deltaG = Gc - G
# line search
- dcost=Mi+reg1*(1+np.log(G)) #??
- alpha,fc,f_val = line_search_armijo(cost,G,deltaG,dcost,f_val)
+ dcost = Mi + reg1 * (1 + np.log(G)) # ??
+ alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
- G=G+alpha*deltaG
+ G = G + alpha * deltaG
# test convergence
- if it>=numItermax:
- loop=0
-
- delta_fval=(f_val-old_fval)/abs(f_val)
- if abs(delta_fval)<stopThr:
- loop=0
+ if it >= numItermax:
+ loop = 0
+ delta_fval = (f_val - old_fval) / abs(f_val)
+ if abs(delta_fval) < stopThr:
+ loop = 0
if log:
log['loss'].append(f_val)
if verbose:
- if it%20 ==0:
- print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,f_val,delta_fval))
-
+ if it % 20 == 0:
+ print('{:5s}|{:12s}|{:8s}'.format(
+ 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
+ print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
if log:
- return G,log
+ return G, log
else:
return G
-
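
A similar hedged sketch for gcg, whose hunk above is likewise a reflow: reg1 is the entropic term handled by the inner Sinkhorn solves, reg2 weights the extra regularizer f. The toy data and the choice of regularizer are illustrative assumptions:

import numpy as np
import ot

n = 20
a = np.ones(n) / n
b = np.ones(n) / n
x = np.arange(n, dtype=np.float64).reshape(-1, 1)
M = ot.dist(x, x)
M /= M.max()

def f(G):
    return 0.5 * np.sum(G ** 2)

def df(G):
    return G

reg1 = 1e-2    # entropic regularization, solved by sinkhorn in the inner loop
reg2 = 1e-1    # weight of the additional regularizer f
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, numInnerItermax=200, log=True)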