summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-28 16:08:41 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-28 16:08:41 +0200
commit549b95b5736b42f3fe74daf9805303a08b1ae01d (patch)
treed4d8ac5252bff2fef688e2fc81087293364b3ac7 /ot/optim.py
parent327b0c6e0ccb0c9453179eb316021c34bcdffec4 (diff)
FGW+gromov changes
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py102
1 files changed, 99 insertions, 3 deletions
diff --git a/ot/optim.py b/ot/optim.py
index f31fae2..a774865 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -4,7 +4,7 @@ Optimization algorithms for OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
-#
+# Titouan Vayer <titouan.vayer@irisa.fr>
# License: MIT License
import numpy as np
@@ -71,9 +71,70 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
def do_linesearch(cost, G, deltaG, Mi, f_val,
                  amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
    r"""
    Solve the line search of a Frank-Wolfe (FW) iteration.

    Two strategies are available: an Armijo backtracking search
    (``line_search_armijo``) or the closed-form minimization of the quadratic
    ``a * alpha**2 + b * alpha + c`` on [0, 1] built from the structure
    matrices. The closed form needs ``C1``, ``C2``, ``constC``, ``reg`` and
    ``M``; if any of them is missing the Armijo search is used instead, so
    that callers which do not provide the structure matrices keep working.

    Parameters
    ----------
    cost : callable
        The FGW cost; called as ``cost(G)``
    G : ndarray, shape (ns, nt)
        The transport map at a given iteration of the FW
    deltaG : ndarray, shape (ns, nt)
        Difference between the optimal map found by linearization in the FW
        algorithm and the value at a given iteration
    Mi : ndarray, shape (ns, nt)
        Cost matrix of the linearized transport problem. Corresponds to the
        gradient of the cost. Only used by the Armijo search
    f_val : float
        Value of the cost at G
    amijo : bool, optional
        If True the step of the line search is found via an Armijo search.
        Else the closed form is used (when its inputs are available).
        If there are convergence issues use False.
    C1 : ndarray, shape (ns, ns), optional
        Structure matrix in the source domain. Only used when amijo=False
    C2 : ndarray, shape (nt, nt), optional
        Structure matrix in the target domain. Only used when amijo=False
    reg : float, optional
        Regularization parameter. Corresponds to the alpha parameter of FGW.
        Only used when amijo=False
    Gc : ndarray, shape (ns, nt), optional
        Optimal map found by linearization in the FW algorithm. Accepted for
        interface compatibility; not read by this function
    constC : ndarray, shape (ns, nt), optional
        Constant for the gromov cost. See [3]. Only used when amijo=False
    M : ndarray, shape (ns, nt), optional
        Cost matrix between the features. Only used when amijo=False

    Returns
    -------
    alpha : float
        The optimal step size of the FW
    fc : int or None
        Number of function calls reported by the Armijo search; None when the
        closed form is used
    f_val : float
        The value of the cost for the next iteration

    References
    ----------
    .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
        and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    # The closed form dereferences all of these; with any of them left at the
    # None default it would fail deep inside numpy with an opaque TypeError.
    closed_form_inputs = (C1, C2, constC, reg, M)
    closed_form_available = all(arg is not None for arg in closed_form_inputs)

    if amijo or not closed_form_available:
        alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
    else:  # requires symmetric matrices
        # C1 . deltaG . C2 is shared by the quadratic and the linear coefficient
        dot1 = np.dot(C1, deltaG)
        dot12 = dot1.dot(C2)
        a = -2 * reg * np.sum(dot12 * deltaG)
        b = (np.sum((M + reg * constC) * deltaG)
             - 2 * reg * (np.sum(dot12 * G)
                          + np.sum(np.dot(C1, G).dot(C2) * deltaG)))
        c = cost(G)

        alpha = solve_1d_linesearch_quad_funct(a, b, c)
        fc = None  # no function-call counter in the closed-form case
        f_val = cost(G + alpha * deltaG)

    return alpha, fc, f_val
+
def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
- stopThr=1e-9, verbose=False, log=False):
+ stopThr=1e-9, verbose=False, log=False,**kwargs):
"""
Solve the general regularized OT problem with conditional gradient
@@ -116,6 +177,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
Print information along iterations
log : bool, optional
record log if True
+ kwargs : dict
+ Parameters for linesearch
Returns
-------
@@ -177,7 +240,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
deltaG = Gc - G
# line search
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+ alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs)
G = G + alpha * deltaG
@@ -339,3 +402,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
return G, log
else:
return G
+
def solve_1d_linesearch_quad_funct(a, b, c):
    r"""
    Minimize a one-dimensional quadratic function on the interval [0, 1].

    Solve on 0,1 the following problem:

    .. math::
        \min f(x)=a*x^{2}+b*x+c

    Parameters
    ----------
    a,b,c : float
        The coefficients of the quadratic function

    Returns
    -------
    x : float
        The optimal value which leads to the minimal cost. When the minimum
        lies on a bound of the interval the integer 0 or 1 is returned; on a
        tie between both bounds (non-convex case) 0 is returned.
    """
    if a > 0:
        # Convex: clip the unconstrained minimizer -b / (2a) to [0, 1].
        # The 2.0 forces true division even with integer coefficients.
        return min(1, max(0, -b / (2.0 * a)))

    # Not convex (concave or linear): the minimum is reached at a bound,
    # so compare f(0) = c with f(1) = a + b + c.
    f0 = c
    f1 = a + f0 + b
    if f0 > f1:
        return 1
    return 0