From 8f56effe7320991ebdc6457a2cf1d3b6648a09d1 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Fri, 10 Mar 2023 13:08:25 +0100 Subject: [WIP] Fix gromov examples gallery (#444) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md --------- Co-authored-by: Rémi Flamary --- .../plot_gromov_wasserstein_dictionary_learning.py | 53 ++++++++++++---------- 1 file changed, 28 insertions(+), 25 deletions(-) (limited to 'examples/gromov/plot_gromov_wasserstein_dictionary_learning.py') diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py index 1fdc3b9..7585944 100755 --- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic import ot import networkx from networkx.generators.community import stochastic_block_model as sbm -# %% -# ============================================================================= + +############################################################################# +# # Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. -# ============================================================================= +# --------------------------------------------- np.random.seed(42) @@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters): pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Estimate the gromov-wasserstein dictionary from the dataset -# ============================================================================= +# --------------------------------------------- np.random.seed(0) @@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the estimated dictionary atoms -# ============================================================================= +# --------------------------------------------- # Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) @@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW): pl.axis("off") pl.tight_layout() pl.show() -#%% -# ============================================================================= + +############################################################################# +# # Visualization of the embedding space -# ============================================================================= +# --------------------------------------------- unmixings = [] reconstruction_errors = [] @@ -211,11 +213,11 @@ pl.axis('off') pl.legend(fontsize=11) pl.tight_layout() pl.show() -# %% -# ============================================================================= -# Endow the dataset with node features -# ============================================================================= +############################################################################# +# +# Endow the dataset with node features +# --------------------------------------------- # We follow this feature assignment on all nodes of a graph depending on its label/number of clusters # 1 cluster --> 0 as nodes feature # 2 clusters --> 1 as nodes feature @@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters): pl.axis("off") pl.tight_layout() pl.show() -# %% -# ============================================================================= + +############################################################################# +# # Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs -# ============================================================================= +# --------------------------------------------- np.random.seed(0) ps = [ot.unif(C.shape[0]) for C in dataset] D = 3 # 6 atoms instead of 3 @@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the estimated dictionary atoms -# ============================================================================= +# --------------------------------------------- pl.figure(7, (12, 8)) pl.clf() @@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the embedding space -# ============================================================================= +# --------------------------------------------- unmixings = [] reconstruction_errors = [] -- cgit v1.2.3