diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2021-11-04 11:00:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-04 11:00:09 +0100 |
commit | 2fe69eb130827560ada704bc25998397c4357821 (patch) | |
tree | 82973444cc4afc4c42cc7cdaf43a2ebd4b1a6a91 /ot/optim.py | |
parent | 9c6ac880d426b7577918b0c77bd74b3b01930ef6 (diff) |
[MRG] Make gromov loss differentiable wrt matrices and weights (#302)
* gromov differentiable
* new stuff
* test gromov gradients
* fgwdifferentiable
* fgw tested
* correct name test
* add awesome example with gromov optimization
* pep8+ typos
* damn pep8
* thumbnail
* remove prints
Diffstat (limited to 'ot/optim.py')
-rw-r--r-- | ot/optim.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/ot/optim.py b/ot/optim.py index cc286b6..bd8ca26 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G @@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G |