diff options
-rw-r--r-- | ot/gromov.py | 2 | ||||
-rw-r--r-- | ot/optim.py | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index 33134a2..44248d1 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -348,7 +348,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
diff --git a/ot/optim.py b/ot/optim.py index cbfb187..2170c7e 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -73,7 +73,7 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, def do_linesearch(cost, G, deltaG, Mi, f_val, - amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + amijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations Parameters |