summaryrefslogtreecommitdiff
path: root/ot/gromov
diff options
context:
space:
mode:
authorCédric Vincent-Cuaz <cedvincentcuaz@gmail.com>2023-03-16 08:05:54 +0100
committerGitHub <noreply@github.com>2023-03-16 08:05:54 +0100
commit583501652517c4f1dbd8572e9f942551a9e54a1f (patch)
treefadb96f888924b2d1bef01b78486e97a88ebcd42 /ot/gromov
parent8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (diff)
[MRG] fix bugs of gw_entropic and armijo to run on gpu (#446)
* maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md * fix bugs of gw_entropic and armijo to run on gpu * add pr to releases.md * fix pep8 * fix call to backend in line_search_armijo * correct docstring generic_conditional_gradient --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/gromov')
-rw-r--r--ot/gromov/_bregman.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py
index 5b2f959..b0cccfb 100644
--- a/ot/gromov/_bregman.py
+++ b/ot/gromov/_bregman.py
@@ -11,9 +11,6 @@ Bregman projections solvers for entropic Gromov-Wasserstein
#
# License: MIT License
-import numpy as np
-
-
from ..bregman import sinkhorn
from ..utils import dist, list_to_array, check_random_state
from ..backend import get_backend
@@ -109,7 +106,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
T = G0
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx)
if symmetric is None:
- symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
if not symmetric:
constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx)
cpt = 0