summaryrefslogtreecommitdiff
path: root/examples/plot_barycenter_fgw.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 17:05:38 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 17:05:48 +0200
commite1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree1e85920b878ab715d211db56f99e25bfa2482fd3 /examples/plot_barycenter_fgw.py
parentd4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
Diffstat (limited to 'examples/plot_barycenter_fgw.py')
-rw-r--r--examples/plot_barycenter_fgw.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
index 9eea036..e4be447 100644
--- a/examples/plot_barycenter_fgw.py
+++ b/examples/plot_barycenter_fgw.py
@@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
colors.append(val_map[node])
return colors
-#%% create dataset
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
# We build a dataset of noisy circular graphs.
# Noise is added on the structures by random connections and on the features by gaussian noise.
@@ -135,7 +139,11 @@ X0 = []
for k in range(9):
X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
-#%% Plot dataset
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
plt.figure(figsize=(8, 10))
for i in range(len(X0)):
@@ -146,9 +154,11 @@ for i in range(len(X0)):
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()
+##############################################################################
+# Barycenter computation
+# ----------------------
-#%%
-# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
# Features distances are the euclidean distances
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
@@ -156,14 +166,16 @@ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()])
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
sizebary = 15 # we choose a barycenter with 15 nodes
-#%%
-
A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)
-#%%
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
-for i in range(len(A.ravel())):
- bary.add_node(i, attr_name=float(A.ravel()[i]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
#%%
pos = nx.kamada_kawai_layout(bary)