summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorNemo Fournier <nemo.fournier@ens-lyon.org>2020-03-09 11:09:54 +0100
committerNemo Fournier <nemo.fournier@ens-lyon.org>2020-03-09 11:09:54 +0100
commit11733534208fecbabae7b707c7b0965c9da1c752 (patch)
treeafcc093e676dbdc694a2e701c5234ff9d3d97af8 /ot/gromov.py
parent0baf83bbff6bd0c67244b3019509fe7518fb2d75 (diff)
fix fgw alpha parameter implementation
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 9869341..7ad7e59 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -493,11 +493,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -573,7 +573,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns