summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 2a70070..dc95c74 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1368,6 +1368,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
@@ -1401,7 +1403,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-4, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -1414,9 +1416,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
err = nx.norm(C - Cprev)
error.append(err)
- if log:
- log['err'].append(err)
-
if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
@@ -1425,7 +1424,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
cpt += 1
- return C
+ if log:
+ return C, {"err": error}
+ else:
+ return C
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
@@ -1479,6 +1481,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
@@ -1513,7 +1517,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -1526,9 +1530,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
err = nx.norm(C - Cprev)
error.append(err)
- if log:
- log['err'].append(err)
-
if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
@@ -1537,7 +1538,10 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
cpt += 1
- return C
+ if log:
+ return C, {"err": error}
+ else:
+ return C
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,