From 9421dddd8890d4c575b593d678eb7bdf5f933f83 Mon Sep 17 00:00:00 2001
From: tvayer
Date: Wed, 29 May 2019 15:51:57 +0200
Subject: Doc+armijo

---
 ot/gromov.py | 39 ++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)

(limited to 'ot/gromov.py')

diff --git a/ot/gromov.py b/ot/gromov.py
index 44248d1..5a57dc8 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -33,12 +33,12 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
         * C2 : Metric cost matrix in the target space
         * T : A coupling between those two spaces
 
-    The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+    The square-loss function L(a,b)=|a-b|^2 is read as :
         L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
-            * f1(a)=(a^2)/2
-            * f2(b)=(b^2)/2
+            * f1(a)=(a^2)
+            * f2(b)=(b^2)
             * h1(a)=a
-            * h2(b)=b
+            * h2(b)=2*b
 
     The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
         L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
     return np.exp(np.divide(tmpsum, ppt))
 
 
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
     """
     Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
 
@@ -307,8 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs)
         Print information along iterations
     log : bool, optional
         record log if True
-    amijo : bool, optional
-        If True the steps of the line-search is found via an amijo research. Else closed form is used.
+    armijo : bool, optional
+        If True the steps of the line-search is found via an armijo research. Else closed form is used.
         If there is convergence issues use False.
     **kwargs : dict
         parameters can be directly pased to the ot.optim.cg solver
@@ -344,14 +344,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs)
         return gwggrad(constC, hC1, hC2, G)
 
     if log:
-        res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+        res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
         log['gw_dist'] = gwloss(constC, hC1, hC2, res)
         return res, log
     else:
-        return cg(p, q, 0, 1, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+        return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
 
 
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
     """
     Computes the FGW distance between two graphs see [3]
     .. math::
@@ -363,6 +363,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
         - M is the (ns,nt) metric cost matrix
         - :math:`f` is the regularization term ( and df is its gradient)
         - a and b are source and target weights (sum to 1)
+        - L is a loss function to account for the misfit between the similarity matrices
     The algorithm used for solving the problem is conditional gradient as discussed in [1]_
     Parameters
     ----------
@@ -386,8 +387,8 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
         Print information along iterations
     log : bool, optional
        record log if True
-    amijo : bool, optional
-        If True the steps of the line-search is found via an amijo research. Else closed form is used.
+    armijo : bool, optional
+        If True the steps of the line-search is found via an armijo research. Else closed form is used.
         If there is convergence issues use False.
     **kwargs : dict
         parameters can be directly pased to the ot.optim.cg solver
@@ -415,10 +416,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
     def df(G):
         return gwggrad(constC, hC1, hC2, G)
 
-    return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+    return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
 
 
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
     """
     Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
 
@@ -456,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs
         Print information along iterations
     log : bool, optional
         record log if True
-    amijo : bool, optional
-        If True the steps of the line-search is found via an amijo research. Else closed form is used.
+    armijo : bool, optional
+        If True the steps of the line-search is found via an armijo research. Else closed form is used.
         If there is convergence issues use False.
     Returns
     -------
@@ -487,7 +488,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs
     def df(G):
         return gwggrad(constC, hC1, hC2, G)
 
-    res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+    res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
     log['gw_dist'] = gwloss(constC, hC1, hC2, res)
     log['T'] = res
     if log:
@@ -890,7 +891,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
                     p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
                     verbose=False, log=True, init_C=None, init_X=None):
     """
-    Compute the fgw barycenter as presented eq (5) in [3].
+    Compute the fgw barycenter as presented eq (5) in [24].
     ----------
     N : integer
         Desired number of samples of the target barycenter
@@ -1065,7 +1066,7 @@ def update_sructure_matrix(p, lambdas, T, Cs):
 
 def update_feature_matrix(lambdas, Ys, Ts, p):
     """
-    Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
+    Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24]
     calculated at each iteration
     Parameters
     ----------
--
cgit v1.2.3
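
For context on the first hunk: with the terms listed in the updated docstring, f1(a) + f2(b) - h1(a)*h2(b) = a^2 + b^2 - 2ab = (a - b)^2, so the decomposition recovers the square loss exactly once the former 1/2 factor is dropped. Below is a minimal NumPy sketch (illustrative only, not part of the patch) that checks this numerically:

    import numpy as np

    rng = np.random.RandomState(0)
    a = rng.rand(5)
    b = rng.rand(5)

    # Decomposition stated in the updated docstring:
    # L(a, b) = f1(a) + f2(b) - h1(a) * h2(b)
    f1 = a ** 2        # f1(a) = a^2
    f2 = b ** 2        # f2(b) = b^2
    h1 = a             # h1(a) = a
    h2 = 2 * b         # h2(b) = 2*b

    decomposed = f1 + f2 - h1 * h2
    direct = (a - b) ** 2          # square loss L(a, b) = |a - b|^2

    assert np.allclose(decomposed, direct)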
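
On the renamed keyword: callers of the solvers touched here pass `armijo` to choose the line-search strategy inside the conditional-gradient solver. A minimal usage sketch, assuming a POT build that includes this patch; the point clouds and sizes are illustrative only:

    import numpy as np
    import ot
    from ot.gromov import gromov_wasserstein

    rng = np.random.RandomState(0)
    xs = rng.randn(20, 2)    # source samples (illustrative)
    xt = rng.randn(30, 2)    # target samples (illustrative)

    # Intra-domain structure matrices and uniform weights
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()
    p = ot.unif(20)
    q = ot.unif(30)

    # Closed-form line-search (default); per the docstring added in this commit,
    # prefer this setting if the Armijo search has convergence issues.
    T = gromov_wasserstein(C1, C2, p, q, 'square_loss', armijo=False)

    # Armijo backtracking line-search inside the conditional-gradient solver.
    T_armijo = gromov_wasserstein(C1, C2, p, q, 'square_loss', armijo=True)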