summaryrefslogtreecommitdiff
path: root/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gromov/plot_gromov_wasserstein_dictionary_learning.py')
-rwxr-xr-xexamples/gromov/plot_gromov_wasserstein_dictionary_learning.py53
1 files changed, 28 insertions, 25 deletions
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
index 1fdc3b9..7585944 100755
--- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
-# =============================================================================
+# ---------------------------------------------
np.random.seed(42)
@@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Estimate the gromov-wasserstein dictionary from the dataset
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
@@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
@@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW):
pl.axis("off")
pl.tight_layout()
pl.show()
-#%%
-# =============================================================================
+
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
@@ -211,11 +213,11 @@ pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
-# Endow the dataset with node features
-# =============================================================================
+#############################################################################
+#
+# Endow the dataset with node features
+# ---------------------------------------------
# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
# 1 cluster --> 0 as nodes feature
# 2 clusters --> 1 as nodes feature
@@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters):
pl.axis("off")
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]
D = 3 # 6 atoms instead of 3
@@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
pl.figure(7, (12, 8))
pl.clf()
@@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []