From 2a32e2ea64d0d5096953a9b8259b0507fa58dca5 Mon Sep 17 00:00:00 2001 From: Kilian Date: Wed, 13 Nov 2019 13:55:24 +0100 Subject: fix log bug in gromov_wasserstein2 --- ot/gromov.py | 156 +++++++++++++++++++++++++++++------------------------------ 1 file changed, 77 insertions(+), 79 deletions(-) (limited to 'ot') diff --git a/ot/gromov.py b/ot/gromov.py index 699ae4c..9869341 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - - H : entropy Parameters ---------- @@ -343,6 +342,83 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): + """ + Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + + The function solves the following optimization problem: + + .. math:: + GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + Where : + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space. + q : ndarray, shape (nt,) + Distribution in the target space. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + + Returns + ------- + gw_dist : float + Gromov-Wasserstein distance + log : dict + convergence information and Coupling marix + + References + ---------- + .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the + metric approach to object matching. Foundations of computational + mathematics 11.4 (2011): 417-487. + + """ + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + + def f(G): + return gwloss(constC, hC1, hC2, G) + + def df(G): + return gwggrad(constC, hC1, hC2, G) + res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) + log_gw['T'] = res + if log: + return log_gw['gw_dist'], log_gw + else: + return log_gw['gw_dist'] + + def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): """ Computes the FGW transport between two graphs see [24] @@ -506,84 +582,6 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 return log['fgw_dist'] -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ - Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) - - The function solves the following optimization problem: - - .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric cost matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space. - q : ndarray, shape (nt,) - Distribution in the target space. - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - log : dict - convergence information and Coupling marix - - References - ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. - - """ - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - G0 = p[:, None] * q[None, :] - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - log['T'] = res - if log: - return log['gw_dist'], log - else: - return log['gw_dist'] - - def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): """ -- cgit v1.2.3