From 2fe69eb130827560ada704bc25998397c4357821 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 4 Nov 2021 11:00:09 +0100 Subject: [MRG] Make gromov loss differentiable wrt matrices and weights (#302) * grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints --- ot/__init__.py | 2 + ot/gromov.py | 141 ++++++++++++++++++++++++++++++++++++++++++++++----------- ot/optim.py | 3 +- 3 files changed, 118 insertions(+), 28 deletions(-) (limited to 'ot') diff --git a/ot/__init__.py b/ot/__init__.py index f20332c..4292b41 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -43,6 +43,8 @@ from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .gromov import (gromov_wasserstein, gromov_wasserstein2, + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) # utils functions from .utils import dist, unif, tic, toc, toq diff --git a/ot/gromov.py b/ot/gromov.py index 465693d..ea667e4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): - """Return the Loss for Gromov-Wasserstein + r"""Return the Loss for Gromov-Wasserstein The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` @@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T): def gwggrad(constC, hC1, hC2, T): - """Return the gradient for Gromov-Wasserstein + r"""Return the gradient for Gromov-Wasserstein The gradient is computed as described in Proposition 2 in :ref:`[12] ` @@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs): def update_kl_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -398,13 +406,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs if log: res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - return res, log + log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log else: - return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - :math:`\mathbf{p}`: distribution in the source space - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -475,13 +501,28 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def df(G): return gwggrad(constC, hC1, hC2, G) - res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) - log_gw['T'] = res + + T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + + T0 = nx.from_numpy(T, type_as=C10) + + log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10) + log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10) + log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10) + log_gw['T'] = T0 + + gw = log_gw['gw_dist'] + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gw = nx.set_gradients(gw, (p0, q0, C10, C20), + (log_gw['u'], log_gw['v'], gC1, gC2)) + if log: - return log_gw['gw_dist'], log_gw + return gw, log_gw else: - return log_gw['gw_dist'] + return gw def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -560,10 +610,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - log['fgw_dist'] = log['loss'][::-1][0] - return res, log + + fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) + + log['fgw_dist'] = fgw_dist + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: - return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[24] ` + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -640,13 +713,27 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + + fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) + + T0 = nx.from_numpy(T, type_as=C10) + + log_fgw['fgw_dist'] = fgw_dist + log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) + log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) + log_fgw['T'] = T0 + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), + (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + if log: - log['fgw_dist'] = log['loss'][::-1][0] - log['T'] = res - return log['fgw_dist'], log + return fgw_dist, log_fgw else: - return log['fgw_dist'] + return fgw_dist def GW_distance_estimation(C1, C2, p, q, loss_fun, T, @@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None): - """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- @@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_structure_matrix(p, lambdas, T, Cs): - """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration @@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" diff --git a/ot/optim.py b/ot/optim.py index cc286b6..bd8ca26 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G @@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G -- cgit v1.2.3