From 11733534208fecbabae7b707c7b0965c9da1c752 Mon Sep 17 00:00:00 2001 From: Nemo Fournier Date: Mon, 9 Mar 2020 11:09:54 +0100 Subject: fix fgw alpha parameter implementation --- ot/gromov.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 9869341..7ad7e59 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -493,11 +493,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) log['fgw_dist'] = log['loss'][::-1][0] return res, log else: - return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -573,7 +573,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) if log: log['fgw_dist'] = log['loss'][::-1][0] log['T'] = res @@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) - T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, + T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns -- cgit v1.2.3 From 20f9abd8633f4a905df97cc5478eae2e53c1aa96 Mon Sep 17 00:00:00 2001 From: Nemo Fournier Date: Mon, 9 Mar 2020 11:38:19 +0100 Subject: clean and complete the document of fgw related functions --- ot/gromov.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 7ad7e59..e329c70 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, where : - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [24]_ @@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Distribution in the target space loss_fun : str, optional Loss function used for the solver - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True + alpha : float, optional + Trade-off parameter (0 < alpha < 1) armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. + log : bool, optional + record log if True **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 where : - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [1]_ @@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Distribution in the target space. loss_fun : str, optional Loss function used for the solver. - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. + log : bool, optional + Record log if True. **kwargs : dict Parameters can be directly pased to the ot.optim.cg solver. @@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Whether to fix the structure of the barycenter during the updates fixed_features : bool Whether to fix the feature of the barycenter during the updates + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshol on error (>0). + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. init_C : ndarray, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. -- cgit v1.2.3 From 18fa98fb109c935dc8d87f9c93318d8cfd118738 Mon Sep 17 00:00:00 2001 From: Nemo Fournier Date: Tue, 10 Mar 2020 15:57:41 +0100 Subject: fixing trailing and before arithmetic operation whitespace issues --- ot/gromov.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index e329c70..43780a4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -488,11 +488,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) log['fgw_dist'] = log['loss'][::-1][0] return res, log else: - return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -563,7 +563,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) if log: log['fgw_dist'] = log['loss'][::-1][0] log['T'] = res @@ -987,13 +987,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' max_iter : int, optional - Max number of iterations + Max number of iterations tol : float, optional Stop threshol on error (>0). verbose : bool, optional Print information along iterations. log : bool, optional - Record log if True. + Record log if True. init_C : ndarray, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. -- cgit v1.2.3