diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2021-10-27 08:41:08 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-27 08:41:08 +0200 |
commit | d7554331fc409fea48ee758fd630909dd9dc4827 (patch) | |
tree | 9b8ed4bf94c12d034d5fb1de5b7b5b76c23b4d05 /ot/gromov.py | |
parent | 76450dddf8dd62b9714b72e99ae075516246d433 (diff) |
[WIP] Sinkhorn in log space (#290)
* adda sinkhorn log and working sinkhorn2 function
* more tests pass
* more tests pass
* it works but not by default yet
* remove warningd
* update circleci doc
* update circleci doc
* new sinkhorn implemeted but not by default
* better
* doctest pass
* test doctest
* new test utils
* remove pep8 errors
* remove pep8 errors
* doc new implementtaion with log
* test sinkhorn 2
* doc for log implementation
Diffstat (limited to 'ot/gromov.py')
-rw-r--r-- | ot/gromov.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index 85b1549..33b4453 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, # compute the gradient
tens = gwggrad(constC, hC1, hC2, T)
- T = sinkhorn(p, q, tens, epsilon)
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
@@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
|