summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCédric Vincent-Cuaz <cedvincentcuaz@gmail.com>2023-03-10 13:08:25 +0100
committerGitHub <noreply@github.com>2023-03-10 13:08:25 +0100
commit8f56effe7320991ebdc6457a2cf1d3b6648a09d1 (patch)
treec814908f038cef1e2945850cb905bf9ba6c06c23
parenta5930d3b3a446bf860d6dfacc1e17151fae1dd1d (diff)
[WIP] Fix gromov examples gallery (#444)
* maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md2
-rwxr-xr-xexamples/gromov/plot_gromov_wasserstein_dictionary_learning.py53
-rw-r--r--examples/gromov/plot_semirelaxed_fgw.py41
3 files changed, 50 insertions, 46 deletions
diff --git a/RELEASES.md b/RELEASES.md
index b51409b..da4d7bb 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -46,7 +46,7 @@ PR #413)
- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
-
+- Fixed an issue with the documentation gallery section (PR #444)
## 0.8.2
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
index 1fdc3b9..7585944 100755
--- a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -45,10 +45,11 @@ from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dic
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
-# =============================================================================
+# ---------------------------------------------
np.random.seed(42)
@@ -109,10 +110,10 @@ for idx_c, c in enumerate(clusters):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Estimate the gromov-wasserstein dictionary from the dataset
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
@@ -140,10 +141,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
@@ -164,10 +165,11 @@ for idx_atom, atom in enumerate(Cdict_GW):
pl.axis("off")
pl.tight_layout()
pl.show()
-#%%
-# =============================================================================
+
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
@@ -211,11 +213,11 @@ pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
-# Endow the dataset with node features
-# =============================================================================
+#############################################################################
+#
+# Endow the dataset with node features
+# ---------------------------------------------
# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
# 1 cluster --> 0 as nodes feature
# 2 clusters --> 1 as nodes feature
@@ -251,10 +253,11 @@ for idx_c, c in enumerate(clusters):
pl.axis("off")
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
-# =============================================================================
+# ---------------------------------------------
np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset]
D = 3 # 6 atoms instead of 3
@@ -280,10 +283,10 @@ pl.ylabel('loss', fontsize=12)
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the estimated dictionary atoms
-# =============================================================================
+# ---------------------------------------------
pl.figure(7, (12, 8))
pl.clf()
@@ -307,10 +310,10 @@ for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the embedding space
-# =============================================================================
+# ---------------------------------------------
unmixings = []
reconstruction_errors = []
diff --git a/examples/gromov/plot_semirelaxed_fgw.py b/examples/gromov/plot_semirelaxed_fgw.py
index 8f879d4..ef4b286 100644
--- a/examples/gromov/plot_semirelaxed_fgw.py
+++ b/examples/gromov/plot_semirelaxed_fgw.py
@@ -31,10 +31,10 @@ from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_w
import networkx
from networkx.generators.community import stochastic_block_model as sbm
-# %%
-# =============================================================================
+#############################################################################
+#
# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
-# =============================================================================
+# ---------------------------------------------
N2 = 20 # 2 communities
@@ -81,10 +81,11 @@ for i, j in G3.edges():
weightedG3.add_edge(i, j, weight=weight_intra_G3)
else:
weightedG3.add_edge(i, j, weight=weight_inter_G3)
-# %%
-# =============================================================================
+
+#############################################################################
+#
# Compute their semi-relaxed Gromov-Wasserstein divergences
-# =============================================================================
+# ---------------------------------------------
# 0) GW(C2, h2, C3, h3) for reference
OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
@@ -106,11 +107,11 @@ print('srGW(C2, h2, C3) = ', srgw_23)
print('srGW(C3, h3, C2) = ', srgw_32)
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the semi-relaxed Gromov-Wasserstein matchings
-# =============================================================================
-
+# ---------------------------------------------
+#
# We color nodes of the graph on the right - then project its node colors
# based on the optimal transport plan from the srGW matching
@@ -222,10 +223,10 @@ pl.tight_layout()
pl.show()
-# %%
-# =============================================================================
+#############################################################################
+#
# Add node features
-# =============================================================================
+# ---------------------------------------------
# We add node features with given mean - by clusters
# and inversely proportional to clusters' intra-connectivity
@@ -238,10 +239,10 @@ F3 = np.zeros((N3, 1))
for i, c in enumerate(part_G3):
F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
-# %%
-# =============================================================================
+#############################################################################
+#
# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
-# =============================================================================
+# ---------------------------------------------
alpha = 0.5
# Compute pairwise euclidean distance between node features
@@ -268,11 +269,11 @@ print('FGW(C2, F2, C3, F3) = ', fgw)
print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23)
print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32)
-# %%
-# =============================================================================
+#############################################################################
+#
# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings
-# =============================================================================
-
+# ---------------------------------------------
+#
# We color nodes of the graph on the right - then project its node colors
# based on the optimal transport plan from the srFGW matching
# NB: colors refer to clusters - not to node features