summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 17:05:38 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 17:05:48 +0200
commite1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree1e85920b878ab715d211db56f99e25bfa2482fd3 /examples
parentd4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_barycenter_fgw.py30
-rw-r--r--examples/plot_fgw.py32
2 files changed, 48 insertions, 14 deletions
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
index 9eea036..e4be447 100644
--- a/examples/plot_barycenter_fgw.py
+++ b/examples/plot_barycenter_fgw.py
@@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
colors.append(val_map[node])
return colors
-#%% create dataset
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
# We build a dataset of noisy circular graphs.
# Noise is added on the structures by random connections and on the features by gaussian noise.
@@ -135,7 +139,11 @@ X0 = []
for k in range(9):
X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
-#%% Plot dataset
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
plt.figure(figsize=(8, 10))
for i in range(len(X0)):
@@ -146,9 +154,11 @@ for i in range(len(X0)):
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()
+##############################################################################
+# Barycenter computation
+# ----------------------
-#%%
-# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
# Features distances are the euclidean distances
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
@@ -156,14 +166,16 @@ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()])
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
sizebary = 15 # we choose a barycenter with 15 nodes
-#%%
-
A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)
-#%%
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
-for i in range(len(A.ravel())):
- bary.add_node(i, attr_name=float(A.ravel()[i]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
#%%
pos = nx.kamada_kawai_layout(bary)
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
index ae3c487..43efc94 100644
--- a/examples/plot_fgw.py
+++ b/examples/plot_fgw.py
@@ -22,12 +22,16 @@ import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+##############################################################################
+# Generate data
+# ---------
+
#%% parameters
# We create two 1D random measures
-n = 20
-n2 = 30
-sig = 1
-sig2 = 0.1
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
np.random.seed(0)
@@ -43,6 +47,10 @@ yt = yt[::-1, :]
p = ot.unif(n)
q = ot.unif(n2)
+##############################################################################
+# Plot data
+# ---------
+
#%% plot the distributions
pl.close(10)
@@ -64,15 +72,22 @@ pl.yticks(())
pl.tight_layout()
pl.show()
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# ---------
#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
-C2 = ot.dist(xt).T
+C2 = ot.dist(xt)
M = ot.dist(ys, yt)
w1 = ot.unif(C1.shape[0])
w2 = ot.unif(C2.shape[0])
Got = ot.emd([], [], M)
+##############################################################################
+# Plot matrices
+# ---------
+
#%%
cmap = 'Reds'
pl.close(10)
@@ -112,6 +127,9 @@ pl.tight_layout()
ax3.set_aspect('auto')
pl.show()
+##############################################################################
+# Compute FGW/GW
+# ---------
#%% Computing FGW and GW
alpha = 1e-3
@@ -123,6 +141,10 @@ ot.toc()
#%reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+##############################################################################
+# Visualize transport matrices
+# ---------
+
#%% visu OT matrix
cmap = 'Blues'
fs = 15