summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-10-27 08:41:08 +0200
committerGitHub <noreply@github.com>2021-10-27 08:41:08 +0200
commitd7554331fc409fea48ee758fd630909dd9dc4827 (patch)
tree9b8ed4bf94c12d034d5fb1de5b7b5b76c23b4d05 /ot/gromov.py
parent76450dddf8dd62b9714b72e99ae075516246d433 (diff)
[WIP] Sinkhorn in log space (#290)
* adda sinkhorn log and working sinkhorn2 function * more tests pass * more tests pass * it works but not by default yet * remove warningd * update circleci doc * update circleci doc * new sinkhorn implemeted but not by default * better * doctest pass * test doctest * new test utils * remove pep8 errors * remove pep8 errors * doc new implementtaion with log * test sinkhorn 2 * doc for log implementation
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 85b1549..33b4453 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1030,7 +1030,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
# compute the gradient
tens = gwggrad(constC, hC1, hC2, T)
- T = sinkhorn(p, q, tens, epsilon)
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
@@ -1204,7 +1204,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)