diff options
author | Nemo Fournier <nemo.fournier@ens-lyon.org> | 2020-03-09 11:09:54 +0100 |
---|---|---|
committer | Nemo Fournier <nemo.fournier@ens-lyon.org> | 2020-03-09 11:09:54 +0100 |
commit | 11733534208fecbabae7b707c7b0965c9da1c752 (patch) | |
tree | afcc093e676dbdc694a2e701c5234ff9d3d97af8 /ot/gromov.py | |
parent | 0baf83bbff6bd0c67244b3019509fe7518fb2d75 (diff) |
fix fgw alpha parameter implementation
Diffstat (limited to 'ot/gromov.py')
-rw-r--r-- | ot/gromov.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index 9869341..7ad7e59 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -493,11 +493,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -573,7 +573,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
|