diff options
author | tvayer <titouan.vayer@gmail.com> | 2019-05-29 17:05:38 +0200 |
---|---|---|
committer | tvayer <titouan.vayer@gmail.com> | 2019-05-29 17:05:48 +0200 |
commit | e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch) | |
tree | 1e85920b878ab715d211db56f99e25bfa2482fd3 /examples/plot_barycenter_fgw.py | |
parent | d4320382fa8873d15dcaec7adca3a4723c142515 (diff) |
code review1
Diffstat (limited to 'examples/plot_barycenter_fgw.py')
-rw-r--r-- | examples/plot_barycenter_fgw.py | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py index 9eea036..e4be447 100644 --- a/examples/plot_barycenter_fgw.py +++ b/examples/plot_barycenter_fgw.py @@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7): colors.append(val_map[node]) return colors -#%% create dataset +############################################################################## +# Generate data +# ------------- + +#%% circular dataset # We build a dataset of noisy circular graphs. # Noise is added on the structures by random connections and on the features by gaussian noise. @@ -135,7 +139,11 @@ X0 = [] for k in range(9): X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) -#%% Plot dataset +############################################################################## +# Plot data +# --------- + +#%% Plot graphs plt.figure(figsize=(8, 10)) for i in range(len(X0)): @@ -146,9 +154,11 @@ for i in range(len(X0)): plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) plt.show() +############################################################################## +# Barycenter computation +# ---------------------- -#%% -# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph # Features distances are the euclidean distances Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] @@ -156,14 +166,16 @@ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]) lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() sizebary = 15 # we choose a barycenter with 15 nodes -#%% - A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95) -#%% +############################################################################## +# Plot Barycenter +# ------------------------- + +#%% Create the barycenter bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) -for i in range(len(A.ravel())): - bary.add_node(i, attr_name=float(A.ravel()[i])) +for i, v in enumerate(A.ravel()): + bary.add_node(i, attr_name=v) #%% pos = nx.kamada_kawai_layout(bary) |