summaryrefslogtreecommitdiff
path: root/examples/gromov/plot_semirelaxed_fgw.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gromov/plot_semirelaxed_fgw.py')
-rw-r--r--examples/gromov/plot_semirelaxed_fgw.py41
1 files changed, 21 insertions, 20 deletions
diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py
index 8f879d4..ef4b286 100644
--- a/examples/gromov/plot_semirelaxed_fgw.py
+++ b/examples/gromov/plot_semirelaxed_fgw.py
@@ -31,10 +31,10 @@ from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_w
import networkx
from networkx.generators.community import stochastic_block_model as sbm
-# %%
-# =============================================================================
+#############################################################################
+#
# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
-# =============================================================================
+# ---------------------------------------------
N2 = 20 # 2 communities
@@ -81,10 +81,11 @@ for i, j in G3.edges():
weightedG3.add_edge(i, j, weight=weight_intra_G3)
else:
weightedG3.add_edge(i, j, weight=weight_inter_G3)
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Compute their semi-relaxed Gromov-Wasserstein divergences
-# =============================================================================
+# ---------------------------------------------
# 0) GW(C2, h2, C3, h3) for reference
OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
@@ -106,11 +107,11 @@ print('srGW(C2, h2, C3) = ', srgw_23)
print('srGW(C3, h3, C2) = ', srgw_32)
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the semi-relaxed Gromov-Wasserstein matchings
-# =============================================================================
-
+# ---------------------------------------------
+#
# We color nodes of the graph on the right - then project its node colors
# based on the optimal transport plan from the srGW matching
@@ -222,10 +223,10 @@ pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Add node features
-# =============================================================================
+# ---------------------------------------------
# We add node features with given mean - by clusters
# and inversely proportional to clusters' intra-connectivity
@@ -238,10 +239,10 @@ F3 = np.zeros((N3, 1))
for i, c in enumerate(part_G3):
F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
-# %%
-# =============================================================================
+#############################################################################
+#
# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
-# =============================================================================
+# ---------------------------------------------
alpha = 0.5
# Compute pairwise euclidean distance between node features
@@ -268,11 +269,11 @@ print('FGW(C2, F2, C3, F3) = ', fgw)
print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23)
print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32)
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings
-# =============================================================================
-
+# ---------------------------------------------
+#
# We color nodes of the graph on the right - then project its node colors
# based on the optimal transport plan from the srFGW matching
# NB: colors refer to clusters - not to node features