summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2020-03-19 13:56:46 +0100
committerGitHub <noreply@github.com>2020-03-19 13:56:46 +0100
commitfa06bb377d083c61f1ac0b067aeeab0fca2b5e7b (patch)
tree19359d03211d53ed5fb3fa0b1bf0a918548b30da /ot
parent599154c22f98eb7c0c5d3f97a6858c474b14dbdd (diff)
parentbb15cdd36aa1ea3e24d3fe36a9c49544c407fdfe (diff)
Merge pull request #133 from little-nem/fgw_fix
[MRG] Fix Fused Gromov Wasserstein implementation
Diffstat (limited to 'ot')
-rw-r--r--ot/gromov.py48
1 files changed, 24 insertions, 24 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 9869341..43780a4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
where :
- M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [24]_
@@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.
+ log : bool, optional
+ record log if True
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -493,11 +488,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
where :
- M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
@@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True the steps of the line-search is found via an armijo research.
Else closed form is used. If there is convergence issues use False.
+ log : bool, optional
+ Record log if True.
**kwargs : dict
Parameters can be directly pased to the ot.optim.cg solver.
@@ -573,7 +563,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
@@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshol on error (>0).
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
init_C : ndarray, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns