From 2a32e2ea64d0d5096953a9b8259b0507fa58dca5 Mon Sep 17 00:00:00 2001 From: Kilian Date: Wed, 13 Nov 2019 13:55:24 +0100 Subject: fix log bug in gromov_wasserstein2 --- ot/gromov.py | 156 +++++++++++++++++++++++++++++------------------------------ 1 file changed, 77 insertions(+), 79 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 699ae4c..9869341 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - - H : entropy Parameters ---------- @@ -343,6 +342,83 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): + """ + Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + + The function solves the following optimization problem: + + .. math:: + GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + Where : + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space. + q : ndarray, shape (nt,) + Distribution in the target space. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + + Returns + ------- + gw_dist : float + Gromov-Wasserstein distance + log : dict + convergence information and Coupling marix + + References + ---------- + .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the + metric approach to object matching. Foundations of computational + mathematics 11.4 (2011): 417-487. + + """ + + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + + G0 = p[:, None] * q[None, :] + + def f(G): + return gwloss(constC, hC1, hC2, G) + + def df(G): + return gwggrad(constC, hC1, hC2, G) + res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) + log_gw['T'] = res + if log: + return log_gw['gw_dist'], log_gw + else: + return log_gw['gw_dist'] + + def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): """ Computes the FGW transport between two graphs see [24] @@ -506,84 +582,6 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 return log['fgw_dist'] -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): - """ - Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) - - The function solves the following optimization problem: - - .. math:: - GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - Where : - - C1 : Metric cost matrix in the source space - - C2 : Metric cost matrix in the target space - - p : distribution in the source space - - q : distribution in the target space - - L : loss function to account for the misfit between the similarity matrices - - H : entropy - - Parameters - ---------- - C1 : ndarray, shape (ns, ns) - Metric cost matrix in the source space - C2 : ndarray, shape (nt, nt) - Metric cost matrix in the target space - p : ndarray, shape (ns,) - Distribution in the source space. - q : ndarray, shape (nt,) - Distribution in the target space. - loss_fun : str - loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - - Returns - ------- - gw_dist : float - Gromov-Wasserstein distance - log : dict - convergence information and Coupling marix - - References - ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the - metric approach to object matching. Foundations of computational - mathematics 11.4 (2011): 417-487. - - """ - - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - - G0 = p[:, None] * q[None, :] - - def f(G): - return gwloss(constC, hC1, hC2, G) - - def df(G): - return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - log['T'] = res - if log: - return log['gw_dist'], log - else: - return log['gw_dist'] - - def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): """ -- cgit v1.2.3 From 11733534208fecbabae7b707c7b0965c9da1c752 Mon Sep 17 00:00:00 2001 From: Nemo Fournier Date: Mon, 9 Mar 2020 11:09:54 +0100 Subject: fix fgw alpha parameter implementation --- ot/gromov.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 9869341..7ad7e59 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -493,11 +493,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) log['fgw_dist'] = log['loss'][::-1][0] return res, log else: - return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -573,7 +573,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) if log: log['fgw_dist'] = log['loss'][::-1][0] log['T'] = res @@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) - T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, + T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns -- cgit v1.2.3 From 20f9abd8633f4a905df97cc5478eae2e53c1aa96 Mon Sep 17 00:00:00 2001 From: Nemo Fournier Date: Mon, 9 Mar 2020 11:38:19 +0100 Subject: clean and complete the document of fgw related functions --- ot/gromov.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 7ad7e59..e329c70 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, where : - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [24]_ @@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Distribution in the target space loss_fun : str, optional Loss function used for the solver - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True + alpha : float, optional + Trade-off parameter (0 < alpha < 1) armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. + log : bool, optional + record log if True **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 where : - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [1]_ @@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Distribution in the target space. loss_fun : str, optional Loss function used for the solver. - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. + log : bool, optional + Record log if True. **kwargs : dict Parameters can be directly pased to the ot.optim.cg solver. @@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Whether to fix the structure of the barycenter during the updates fixed_features : bool Whether to fix the feature of the barycenter during the updates + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshol on error (>0). + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. init_C : ndarray, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. -- cgit v1.2.3 From 18fa98fb109c935dc8d87f9c93318d8cfd118738 Mon Sep 17 00:00:00 2001 From: Nemo Fournier Date: Tue, 10 Mar 2020 15:57:41 +0100 Subject: fixing trailing and before arithmetic operation whitespace issues --- ot/gromov.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index e329c70..43780a4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -488,11 +488,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) log['fgw_dist'] = log['loss'][::-1][0] return res, log else: - return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -563,7 +563,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) if log: log['fgw_dist'] = log['loss'][::-1][0] log['T'] = res @@ -987,13 +987,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' max_iter : int, optional - Max number of iterations + Max number of iterations tol : float, optional Stop threshol on error (>0). verbose : bool, optional Print information along iterations. log : bool, optional - Record log if True. + Record log if True. init_C : ndarray, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. -- cgit v1.2.3 From ad7aa892b47f039366a30103c1cede804811fb46 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 24 Apr 2020 16:39:29 +0200 Subject: better doc per module --- ot/__init__.py | 30 ------------------------------ ot/bregman.py | 2 +- ot/datasets.py | 2 +- ot/dr.py | 4 ++-- ot/gpu/__init__.py | 5 +++-- ot/gromov.py | 2 +- ot/optim.py | 2 +- ot/partial.py | 2 +- ot/smooth.py | 4 +++- ot/unbalanced.py | 2 +- 10 files changed, 14 insertions(+), 41 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/__init__.py b/ot/__init__.py index 1e57b78..2d23610 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,35 +1,5 @@ """ -This is the main module of the POT toolbox. It provides easy access to -a number of sub-modules and functions described below. - -.. note:: - - - Here is a list of the submodules and short description of what they contain. - - - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. - - :any:`ot.bregman` contains OT solvers for the entropic OT problems using - Bregman projections. - - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems. - - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT - problems. - - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov - Wasserstein problems. - - :any:`ot.optim` contains generic solvers OT based optimization problems - - :any:`ot.da` contains classes and function related to Monge mapping - estimation and Domain Adaptation (DA). - - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers - - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein - Discriminant Analysis. - - :any:`ot.utils` contains utility functions such as distance computation and - timing. - - :any:`ot.datasets` contains toy dataset generation functions. - - :any:`ot.plot` contains visualization functions - - :any:`ot.stochastic` contains stochastic solvers for regularized OT. - - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT. - - :any:`ot.partial` contains solvers for partial OT. - .. warning:: The list of automatically imported sub-modules is as follows: :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` diff --git a/ot/bregman.py b/ot/bregman.py index b4365d0..f1f8437 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Bregman projections for regularized OT +Bregman projections solvers for entropic regularized OT """ # Author: Remi Flamary diff --git a/ot/datasets.py b/ot/datasets.py index daca1ae..bd39e13 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -1,5 +1,5 @@ """ -Simple example datasets for OT +Simple example datasets """ # Author: Remi Flamary diff --git a/ot/dr.py b/ot/dr.py index 680dabf..11d2e10 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -Dimension reduction with optimal transport +Dimension reduction with OT .. warning:: - Note that by default the module is not import in :mod:`ot`. In order to + Note that by default the module is not imported in :mod:`ot`. In order to use it you need to explicitely import :mod:`ot.dr` """ diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 1ab95bb..7478fb9 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- """ +GPU implementation for several OT solvers and utility +functions. -This module provides GPU implementation for several OT solvers and utility -functions. The GPU backend in handled by `cupy +The GPU backend in handled by `cupy `_. .. warning:: diff --git a/ot/gromov.py b/ot/gromov.py index 43780a4..a678722 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Gromov-Wasserstein transport method +Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers """ # Author: Erwan Vautier diff --git a/ot/optim.py b/ot/optim.py index 4012e0d..b9ca891 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Optimization algorithms for OT +Generic solvers for regularized OT """ # Author: Remi Flamary diff --git a/ot/partial.py b/ot/partial.py index c03ec25..eb707d8 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Partial OT +Partial OT solvers """ # Author: Laetitia Chapel diff --git a/ot/smooth.py b/ot/smooth.py index 5a8e4b5..81f6a3e 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -26,7 +26,9 @@ # Remi Flamary """ -Implementation of +Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) + +Implementation of : Smooth and Sparse Optimal Transport. Mathieu Blondel, Vivien Seguy, Antoine Rolet. In Proc. of AISTATS 2018. diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 23f6607..e37f10c 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Regularized Unbalanced OT +Regularized Unbalanced OT solvers """ # Author: Hicham Janati -- cgit v1.2.3 From 08ec66ede42350dd040b559cca181d2e599c3d2d Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 24 Apr 2020 16:41:56 +0200 Subject: pep8 --- ot/datasets.py | 2 +- ot/gromov.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/datasets.py b/ot/datasets.py index bd39e13..b86ef3b 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -1,5 +1,5 @@ """ -Simple example datasets +Simple example datasets """ # Author: Remi Flamary diff --git a/ot/gromov.py b/ot/gromov.py index a678722..4427a96 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers +Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers """ # Author: Erwan Vautier -- cgit v1.2.3