summaryrefslogtreecommitdiff
path: root/ot/optim.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-11-04 11:00:09 +0100
committerGitHub <noreply@github.com>2021-11-04 11:00:09 +0100
commit2fe69eb130827560ada704bc25998397c4357821 (patch)
tree82973444cc4afc4c42cc7cdaf43a2ebd4b1a6a91 /ot/optim.py
parent9c6ac880d426b7577918b0c77bd74b3b01930ef6 (diff)
[MRG] Make gromov loss differentiable wrt matrices and weights (#302)
* grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints
Diffstat (limited to 'ot/optim.py')
-rw-r--r--ot/optim.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/ot/optim.py b/ot/optim.py
index cc286b6..bd8ca26 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
Mi += nx.min(Mi)
# solve linear program
- Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
+ Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
deltaG = Gc - G
@@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
+ log.update(logemd)
return G, log
else:
return G