diff options
Diffstat (limited to 'examples/gromov/plot_semirelaxed_fgw.py')
-rw-r--r-- | examples/gromov/plot_semirelaxed_fgw.py | 41 |
1 files changed, 21 insertions, 20 deletions
diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py index 8f879d4..ef4b286 100644 --- a/examples/gromov/plot_semirelaxed_fgw.py +++ b/examples/gromov/plot_semirelaxed_fgw.py @@ -31,10 +31,10 @@ from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_w import networkx from networkx.generators.community import stochastic_block_model as sbm -# %% -# ============================================================================= +############################################################################# +# # Generate two graphs following Stochastic Block models of 2 and 3 clusters. -# ============================================================================= +# --------------------------------------------- N2 = 20 # 2 communities @@ -81,10 +81,11 @@ for i, j in G3.edges(): weightedG3.add_edge(i, j, weight=weight_intra_G3) else: weightedG3.add_edge(i, j, weight=weight_inter_G3) -# %% -# ============================================================================= + +############################################################################# +# # Compute their semi-relaxed Gromov-Wasserstein divergences -# ============================================================================= +# --------------------------------------------- # 0) GW(C2, h2, C3, h3) for reference OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) @@ -106,11 +107,11 @@ print('srGW(C2, h2, C3) = ', srgw_23) print('srGW(C3, h3, C2) = ', srgw_32) -# %% -# ============================================================================= +############################################################################# +# # Visualization of the semi-relaxed Gromov-Wasserstein matchings -# ============================================================================= - +# --------------------------------------------- +# # We color nodes of the graph on the right - then project its node colors # based on the optimal transport plan from the srGW matching @@ -222,10 +223,10 @@ pl.tight_layout() pl.show() -# %% -# ============================================================================= +############################################################################# +# # Add node features -# ============================================================================= +# --------------------------------------------- # We add node features with given mean - by clusters # and inversely proportional to clusters' intra-connectivity @@ -238,10 +239,10 @@ F3 = np.zeros((N3, 1)) for i, c in enumerate(part_G3): F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) -# %% -# ============================================================================= +############################################################################# +# # Compute their semi-relaxed Fused Gromov-Wasserstein divergences -# ============================================================================= +# --------------------------------------------- alpha = 0.5 # Compute pairwise euclidean distance between node features @@ -268,11 +269,11 @@ print('FGW(C2, F2, C3, F3) = ', fgw) print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23) print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32) -# %% -# ============================================================================= +############################################################################# +# # Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings -# ============================================================================= - +# --------------------------------------------- +# # We color nodes of the graph on the right - then project its node colors # based on the optimal transport plan from the srFGW matching # NB: colors refer to clusters - not to node features |