diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2019-06-04 11:57:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-04 11:57:18 +0200 |
commit | 5a6b226de20624b51c2ff98bc30e5611a7a788c7 (patch) | |
tree | 69b019aaa43ec7d69d97a48717eed27c01890c6e /ot | |
parent | f66ab58c7c895011fd37bafd3e848828399c56c4 (diff) | |
parent | 788a6506c9bf3b862a9652d74f65f8d07851e653 (diff) |
Merge pull request #86 from tvayer/master
[MRG] Gromov-Wasserstein closed form for linesearch and integration of Fused Gromov-Wasserstein
This PR closes #82
Thank you @tvayer for all the work.
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 1 | ||||
-rw-r--r-- | ot/gromov.py | 388 | ||||
-rw-r--r-- | ot/optim.py | 145 | ||||
-rw-r--r-- | ot/utils.py | 8 |
4 files changed, 499 insertions, 43 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index dc43834..321712b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -6,6 +6,7 @@ Bregman projections for regularized OT # Author: Remi Flamary <remi.flamary@unice.fr> # Nicolas Courty <ncourty@irisa.fr> # Kilian Fatras <kilian.fatras@irisa.fr> +# Titouan Vayer <titouan.vayer@irisa.fr> # # License: MIT License diff --git a/ot/gromov.py b/ot/gromov.py index 7974546..ca96b31 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -9,17 +9,19 @@ Gromov-Wasserstein transport method # Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
import numpy as np
+
from .bregman import sinkhorn
-from .utils import dist
+from .utils import dist, UndefinedParameter
from .optim import cg
-def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
+def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
""" Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
@@ -32,12 +34,12 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): * C2 : Metric cost matrix in the target space
* T : A coupling between those two spaces
- The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+ The square-loss function L(a,b)=|a-b|^2 is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=(a^2)/2
- * f2(b)=(b^2)/2
+ * f1(a)=(a^2)
+ * f2(b)=(b^2)
* h1(a)=a
- * h2(b)=b
+ * h2(b)=2*b
The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
@@ -77,16 +79,16 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): if loss_fun == 'square_loss':
def f1(a):
- return (a**2) / 2
+ return (a**2)
def f2(b):
- return (b**2) / 2
+ return (b**2)
def h1(a):
return a
def h2(b):
- return b
+ return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * np.log(a + 1e-15) - a
@@ -268,7 +270,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -306,6 +308,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations
log : bool, optional
record log if True
+ armijo : bool, optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
**kwargs : dict
parameters can be directly pased to the ot.optim.cg solver
@@ -329,9 +334,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): """
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -342,14 +345,161 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW transport between two graphs see [24]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix respresentative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix espresentative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string,optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
+ **kwargs : dict
+ parameters can be directly pased to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][::-1][0]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs see [24]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix respresentative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix espresentative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string,optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
+ **kwargs : dict
+ parameters can be directly pased to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ if log:
+ log['fgw_dist'] = log['loss'][::-1][0]
+ log['T'] = res
+ return log['fgw_dist'], log
+ else:
+ return log['fgw_dist']
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -387,7 +537,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations
log : bool, optional
record log if True
-
+ armijo : bool, optional
+ If True the steps of the line-search is found via an armijo research. Else closed form is used.
+ If there is convergence issues use False.
Returns
-------
gw_dist : float
@@ -407,9 +559,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): """
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -418,7 +568,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
@@ -495,7 +645,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, T = np.outer(p, q) # Initialization
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
cpt = 0
err = 1
@@ -815,3 +965,197 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, cpt += 1
return C
+
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=False, init_C=None, init_X=None):
+ """
+ Compute the fgw barycenter as presented eq (5) in [24].
+ ----------
+ N : integer
+ Desired number of samples of the target barycenter
+ Ys: list of ndarray, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of ndarray, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of ndarray, each element has shape (ns,)
+ masses of all samples
+ lambdas : list of float
+ list of the S spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Wether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Wether to fix the feature of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ initialization for the barycenters' structure matrix. If not set random init
+ init_X : ndarray, shape (N,d), optional
+ initialization for the barycenters' features. If not set random init
+ Returns
+ ----------
+ X : ndarray, shape (N,d)
+ Barycenters' features
+ C : ndarray, shape (N,N)
+ Barycenters' structure matrix
+ log_: dictionary
+ Only returned when log=True
+ T : list of (N,ns) transport matrices
+ Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ S = len(Cs)
+ d = Ys[0].shape[1] # dimension on the node features
+ if p is None:
+ p = np.ones(N) / N
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
+
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = np.zeros((N, d))
+ else:
+ X = init_X
+
+ T = [np.outer(p, q) for q in ps]
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while((err_feature > tol or err_structure > tol) and cpt < max_iter):
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ T_temp = [t.T for t in T]
+ C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
+ err_structure = np.linalg.norm(C - Cprev)
+
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return X, C, log_
+ else:
+ return X, C
+
+
+def update_sructure_matrix(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 Loss kernel with the S Ts couplings
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices
+ Returns
+ ----------
+ C : ndarray, shape (nt,nt)
+ updated C matrix
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ """
+ Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24]
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ Ts : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Ys : list of S ndarray, shape(d,ns)
+ The features
+ Returns
+ ----------
+ X : ndarray, shape (d,N)
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ p = np.diag(np.array(1 / p).reshape(-1,))
+
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
+
+ return tmpsum
diff --git a/ot/optim.py b/ot/optim.py index f31fae2..f94aceb 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -4,6 +4,7 @@ Optimization algorithms for OT """ # Author: Remi Flamary <remi.flamary@unice.fr> +# Titouan Vayer <titouan.vayer@irisa.fr> # # License: MIT License @@ -72,8 +73,70 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, return alpha, fc[0], phi1 +def solve_linesearch(cost, G, deltaG, Mi, f_val, + armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + """ + Solve the linesearch in the FW iterations + Parameters + ---------- + cost : method + Cost in the FW for the linesearch + G : ndarray, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : ndarray (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + Mi : ndarray (ns,nt) + Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost + f_val : float + Value of the cost at G + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + C1 : ndarray (ns,ns), optional + Structure matrix in the source domain. Only used and necessary when armijo=False + C2 : ndarray (nt,nt), optional + Structure matrix in the target domain. Only used and necessary when armijo=False + reg : float, optional + Regularization parameter. Only used and necessary when armijo=False + Gc : ndarray (ns,nt) + Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False + constC : ndarray (ns,nt) + Constant for the gromov cost. See [24]. Only used and necessary when armijo=False + M : ndarray (ns,nt), optional + Cost matrix between the features. Only used and necessary when armijo=False + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if armijo: + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) + + return alpha, fc, f_val + + def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False): + stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -111,11 +174,15 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + **kwargs : dict + Parameters for linesearch Returns ------- @@ -157,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -177,7 +244,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -185,8 +252,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, if it >= numItermax: loop = 0 - delta_fval = (f_val - old_fval) / abs(f_val) - if abs(delta_fval) < stopThr: + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -194,9 +262,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: return G, log @@ -205,7 +273,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, - numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): + numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with the generalized conditional gradient @@ -248,7 +316,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on error (>0) + Stop threshol on the relative variation (>0) + stopThr2 : float, optional + Stop threshol on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -294,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -322,8 +392,10 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, if it >= numItermax: loop = 0 - delta_fval = (f_val - old_fval) / abs(f_val) - if abs(delta_fval) < stopThr: + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: loop = 0 if log: @@ -331,11 +403,42 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: return G, log else: return G + + +def solve_1d_linesearch_quad(a, b, c): + """ + For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: + .. math:: + \argmin f(x)=a*x^{2}+b*x+c + + Parameters + ---------- + a,b,c : float + The coefficients of the quadratic function + + Returns + ------- + x : float + The optimal value which leads to the minimal cost + + """ + f0 = c + df0 = b + f1 = a + f0 + df0 + + if a > 0: # convex + minimum = min(1, max(0, np.divide(-b, 2.0 * a))) + return minimum + else: # non convex + if f0 > f1: + return 1 + else: + return 0 diff --git a/ot/utils.py b/ot/utils.py index bb21b38..efd1288 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -487,3 +487,11 @@ class BaseEstimator(object): (key, self.__class__.__name__)) setattr(self, key, value) return self + + +class UndefinedParameter(Exception): + """ + Aim at raising an Exception when a undefined parameter is called + + """ + pass |