diff options
author | Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> | 2023-03-10 13:08:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-10 13:08:25 +0100 |
commit | 8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (patch) | |
tree | c814908f038cef1e2945850cb905bf9ba6c06c23 | |
parent | a5930d3b3a446bf860d6dfacc1e17151fae1dd1d (diff) |
[WIP] Fix gromov examples gallery (#444)
* maj gw/ srgw/ generic cg solver
* correct pep8 on current state
* fix bug previous tests
* fix pep8
* fix bug srGW constC in loss and gradient
* fix doc html
* fix doc html
* start updating test_optim.py
* update tests gromov and optim - plus fix gromov dependencies
* add symmetry feature to entropic gw
* add symmetry feature to entropic gw
* add exemple for sr(F)GW matchings
* small stuff
* remove (reg,M) from line-search/ complete srgw tests with backend
* remove backend repetitions / rename fG to costG/ fix innerlog to True
* fix pep8
* take comments into account / new nx parameters still to test
* factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions
* split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions
* manual documentaion gromov
* remove circular autosummary
* trying stuff
* debug documentation
* alphabetic ordering of module
* merge into branch
* add note in entropic gw solvers
* fix exemples/gromov doc
* add fixed issue to releases.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 2 | ||||
-rwxr-xr-x | examples/gromov/plot_gromov_wasserstein_dictionary_learning.py | 53 | ||||
-rw-r--r-- | examples/gromov/plot_semirelaxed_fgw.py | 41 |
3 files changed, 50 insertions, 46 deletions
diff --git a/RELEASES.md b/RELEASES.md index b51409b..da4d7bb 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -46,7 +46,7 @@ PR #413) - Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) - +- Fixed an issue with the documentation gallery section (PR #444) ## 0.8.2 diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py index 1fdc3b9..7585944 100755 --- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic import ot import networkx from networkx.generators.community import stochastic_block_model as sbm -# %% -# ============================================================================= + +############################################################################# +# # Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. -# ============================================================================= +# --------------------------------------------- np.random.seed(42) @@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters): pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Estimate the gromov-wasserstein dictionary from the dataset -# ============================================================================= +# --------------------------------------------- np.random.seed(0) @@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the estimated dictionary atoms -# ============================================================================= +# --------------------------------------------- # Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) @@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW): pl.axis("off") pl.tight_layout() pl.show() -#%% -# ============================================================================= + +############################################################################# +# # Visualization of the embedding space -# ============================================================================= +# --------------------------------------------- unmixings = [] reconstruction_errors = [] @@ -211,11 +213,11 @@ pl.axis('off') pl.legend(fontsize=11) pl.tight_layout() pl.show() -# %% -# ============================================================================= -# Endow the dataset with node features -# ============================================================================= +############################################################################# +# +# Endow the dataset with node features +# --------------------------------------------- # We follow this feature assignment on all nodes of a graph depending on its label/number of clusters # 1 cluster --> 0 as nodes feature # 2 clusters --> 1 as nodes feature @@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters): pl.axis("off") pl.tight_layout() pl.show() -# %% -# ============================================================================= + +############################################################################# +# # Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs -# ============================================================================= +# --------------------------------------------- np.random.seed(0) ps = [ot.unif(C.shape[0]) for C in dataset] D = 3 # 6 atoms instead of 3 @@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the estimated dictionary atoms -# ============================================================================= +# --------------------------------------------- pl.figure(7, (12, 8)) pl.clf() @@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Visualization of the embedding space -# ============================================================================= +# --------------------------------------------- unmixings = [] reconstruction_errors = [] diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py index 8f879d4..ef4b286 100644 --- a/examples/gromov/plot_semirelaxed_fgw.py +++ b/examples/gromov/plot_semirelaxed_fgw.py @@ -31,10 +31,10 @@ from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_w import networkx from networkx.generators.community import stochastic_block_model as sbm -# %% -# ============================================================================= +############################################################################# +# # Generate two graphs following Stochastic Block models of 2 and 3 clusters. -# ============================================================================= +# --------------------------------------------- N2 = 20 # 2 communities @@ -81,10 +81,11 @@ for i, j in G3.edges(): weightedG3.add_edge(i, j, weight=weight_intra_G3) else: weightedG3.add_edge(i, j, weight=weight_inter_G3) -# %% -# ============================================================================= + +############################################################################# +# # Compute their semi-relaxed Gromov-Wasserstein divergences -# ============================================================================= +# --------------------------------------------- # 0) GW(C2, h2, C3, h3) for reference OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) @@ -106,11 +107,11 @@ print('srGW(C2, h2, C3) = ', srgw_23) print('srGW(C3, h3, C2) = ', srgw_32) -# %% -# ============================================================================= +############################################################################# +# # Visualization of the semi-relaxed Gromov-Wasserstein matchings -# ============================================================================= - +# --------------------------------------------- +# # We color nodes of the graph on the right - then project its node colors # based on the optimal transport plan from the srGW matching @@ -222,10 +223,10 @@ pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Add node features -# ============================================================================= +# --------------------------------------------- # We add node features with given mean - by clusters # and inversely proportional to clusters' intra-connectivity @@ -238,10 +239,10 @@ F3 = np.zeros((N3, 1)) for i, c in enumerate(part_G3): F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) -# %% -# ============================================================================= +############################################################################# +# # Compute their semi-relaxed Fused Gromov-Wasserstein divergences -# ============================================================================= +# --------------------------------------------- alpha = 0.5 # Compute pairwise euclidean distance between node features @@ -268,11 +269,11 @@ print('FGW(C2, F2, C3, F3) = ', fgw) print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23) print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32) -# %% -# ============================================================================= +############################################################################# +# # Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings -# ============================================================================= - +# --------------------------------------------- +# # We color nodes of the graph on the right - then project its node colors # based on the optimal transport plan from the srFGW matching # NB: colors refer to clusters - not to node features |