From e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:05:38 +0200 Subject: code review1 --- ot/gromov.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++------- ot/optim.py | 31 ++++++++--------- 2 files changed, 112 insertions(+), 27 deletions(-) (limited to 'ot') diff --git a/ot/gromov.py b/ot/gromov.py index 5a57dc8..53349b7 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -10,6 +10,7 @@ Gromov-Wasserstein transport method # Nicolas Courty # Rémi Flamary # Titouan Vayer +# # License: MIT License import numpy as np @@ -351,9 +352,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs): +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): """ - Computes the FGW distance between two graphs see [3] + Computes the FGW transport between two graphs see [24] .. math:: \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} s.t. 
\gamma 1 = p @@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, distribution in the source space q : ndarray, shape (nt,) distribution in the target space - loss_fun : string,optionnal + loss_fun : string,optional loss function used for the solver max_iter : int, optional Max number of iterations @@ -416,7 +417,86 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, def df(G): return gwggrad(constC, hC1, hC2, G) - return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + if log: + res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + log['fgw_dist'] = log['loss'][::-1][0] + return res, log + else: + return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + + +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): + """ + Computes the FGW distance between two graphs see [24] + .. math:: + \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. 
\gamma 1 = p + \gamma^T 1= q + \gamma\geq 0 + where : + - M is the (ns,nt) metric cost matrix + - :math:`f` is the regularization term ( and df is its gradient) + - p and q are source and target weights (sum to 1) + - L is a loss function to account for the misfit between the similarity matrices + The algorithm used for solving the problem is conditional gradient as discussed in [24]_ + Parameters + ---------- + M : ndarray, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix representative of the structure in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix representative of the structure in the target space + p : ndarray, shape (ns,) + distribution in the source space + q : ndarray, shape (nt,) + distribution in the target space + loss_fun : string,optional + loss function used for the solver + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True, the step of the line-search is found via an Armijo search. Else closed form is used. + If there are convergence issues, use False. + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + Returns + ------- + fgw_dist : float + FGW distance for the given parameters + log : dict + log dictionary returned only if log==True in parameters + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. 
+ """ + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + + def f(G): + return gwloss(constC, hC1, hC2, G) + + def df(G): + return gwggrad(constC, hC1, hC2, G) + + res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + if log: + log['fgw_dist'] = log['loss'][::-1][0] + log['T'] = res + return log['fgw_dist'], log + else: + return log['fgw_dist'] def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): @@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, - verbose=False, log=True, init_C=None, init_X=None): + verbose=False, log=False, init_C=None, init_X=None): """ Compute the fgw barycenter as presented eq (5) in [24]. ---------- @@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Barycenters' features C : ndarray, shape (N,N) Barycenters' structure matrix - log_: + log_: dictionary + Only returned when log=True T : list of (N,ns) transport matrices Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) References @@ -1015,14 +1096,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns - - log_['Ts_iter'].append(T) err_feature = np.linalg.norm(X - Xprev.reshape(N, d)) err_structure = np.linalg.norm(C - Cprev) if log: log_['err_feature'].append(err_feature) log_['err_structure'].append(err_structure) + log_['Ts_iter'].append(T) if verbose: if cpt % 200 == 0: @@ -1032,11 +1112,15 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ 
print('{:5d}|{:8e}|'.format(cpt, err_feature)) cpt += 1 - log_['T'] = T # from target to Ys - log_['p'] = p - log_['Ms'] = Ms # Ms are N,ns + if log: + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms # Ms are N,ns - return X, C, log_ + if log: + return X, C, log_ + else: + return X, C def update_sructure_matrix(p, lambdas, T, Cs): diff --git a/ot/optim.py b/ot/optim.py index 7d103e2..4d428d9 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -5,6 +5,7 @@ Optimization algorithms for OT # Author: Remi Flamary # Titouan Vayer +# # License: MIT License import numpy as np @@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val, Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost f_val : float Value of the cost at G - armijo : bool, optionnal + armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. - C1 : ndarray (ns,ns), optionnal + C1 : ndarray (ns,ns), optional Structure matrix in the source domain. Only used when armijo=False - C2 : ndarray (nt,nt), optionnal + C2 : ndarray (nt,nt), optional Structure matrix in the target domain. Only used when armijo=False - reg : float, optionnal + reg : float, optional Regularization parameter. Only used when armijo=False Gc : ndarray (ns,nt) Optimal map found by linearization in the FW algorithm. Only used when armijo=False constC : ndarray (ns,nt) Constant for the gromov cost. See [24]. Only used when armijo=False - M : ndarray (ns,nt), optionnal + M : ndarray (ns,nt), optional Cost matrix between the features. 
Only used when armijo=False Returns ------- @@ -223,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -261,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: @@ -363,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, it = 0 if verbose: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32) - print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0)) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) + print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0)) while loop: @@ -402,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, if verbose: if it % 20 == 0: - print('{:5s}|{:12s}|{:8s}'.format( - 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32) + print('{:5s}|{:12s}|{:8s}|{:8s}'.format( + 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48) print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: -- cgit v1.2.3