summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_barycenter_fgw.py172
-rw-r--r--examples/plot_fgw.py151
2 files changed, 323 insertions, 0 deletions
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
new file mode 100644
index 0000000..f416629
--- /dev/null
+++ b/examples/plot_barycenter_fgw.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+#%% load libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import math
+from scipy.sparse.csgraph import shortest_path
+import matplotlib.colors as mcol
+from matplotlib import cm
+from ot.gromov import fgw_barycenters
+#%% Graph functions
+
+def find_thresh(C,inf=0.5,sup=3,step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist=[]
+ search=np.linspace(inf,sup,step)
+ for thresh in search:
+ Cprime=sp_to_adjency(C,0,thresh)
+ SC=shortest_path(Cprime,method='D')
+ SC[SC==float('inf')]=100
+ dist.append(np.linalg.norm(SC-C))
+ return search[np.argmin(dist)],dist
+
+def sp_to_adjency(C,threshinf=0.2,threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance from which the new value is set to 1
+ threshsup : float
+ The maximum value of distance from which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The threshold matrix. Each element is in {0,1}
+ """
+ H=np.zeros_like(C)
+ np.fill_diagonal(H,np.diagonal(C))
+ C=C-H
+ C=np.minimum(np.maximum(C,threshinf),threshsup)
+ C[C==threshsup]=0
+ C[C!=0]=1
+
+ return C
+
+def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None):
+ """ Create a noisy circular graph
+ """
+ g=nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise=float(np.random.normal(mu,sigma,1))
+ if with_noise:
+ g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise)
+ else:
+ g.add_node(i,attr_name=math.sin(2*i*math.pi/N))
+ g.add_edge(i,i+1)
+ if structure_noise:
+ randomint=np.random.randint(0,p)
+ if randomint==0:
+ if i<=N-3:
+ g.add_edge(i,i+2)
+ if i==N-2:
+ g.add_edge(i,0)
+ if i==N-1:
+ g.add_edge(i,1)
+ g.add_edge(N,0)
+ noise=float(np.random.normal(mu,sigma,1))
+ if with_noise:
+ g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise)
+ else:
+ g.add_node(N,attr_name=math.sin(2*N*math.pi/N))
+ return g
+
+def graph_colors(nx_graph,vmin=0,vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin,vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k,v in nx.get_node_attributes(nx_graph,'attr_name').items():
+ val_map[k]=cpick.to_rgba(v)
+ colors=[]
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+#%% create dataset
+# We build a dataset of noisy circular graphs.
+# Noise is added on the structures by random connections and on the features by gaussian noise.
+
+np.random.seed(30)
+X0=[]
+for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3))
+
+#%% Plot dataset
+
+plt.figure(figsize=(8,10))
+for i in range(len(X0)):
+ plt.subplot(3,3,i+1)
+ g=X0[i]
+ pos=nx.kamada_kawai_layout(g)
+ nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100)
+plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20)
+plt.show()
+
+
+
+#%%
+# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+# Features distances are the euclidean distances
+Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0]
+Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0]
+lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel()
+sizebary=15 # we choose a barycenter with 15 nodes
+
+#%%
+
+A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95)
+
+#%%
+bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0]))
+for i in range(len(A.ravel())):
+ bary.add_node(i,attr_name=float(A.ravel()[i]))
+
+#%%
+pos = nx.kamada_kawai_layout(bary)
+nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False)
+plt.suptitle('Barycenter',fontsize=20)
+plt.show()
+
+
+
+
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
new file mode 100644
index 0000000..bfa7fb4
--- /dev/null
+++ b/examples/plot_fgw.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+Plot Fused-gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures[18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import matplotlib.pyplot as pl
+import numpy as np
+import ot
+from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein
+
+#%% parameters
+# We create two 1D random measures
+n=20
+n2=30
+sig=1
+sig2=0.1
+
+np.random.seed(0)
+
+phi=np.arange(n)[:,None]
+xs=phi+sig*np.random.randn(n,1)
+ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1)
+
+phi2=np.arange(n2)[:,None]
+xt=phi2+sig*np.random.randn(n2,1)
+yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1)
+yt= yt[::-1,:]
+
+p=ot.unif(n)
+q=ot.unif(n2)
+
+#%% plot the distributions
+
+pl.close(10)
+pl.figure(10,(7,7))
+
+pl.subplot(2,1,1)
+
+pl.scatter(ys,xs,c=phi,s=70)
+pl.ylabel('Feature value a',fontsize=20)
+pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1)
+pl.xticks(())
+pl.yticks(())
+pl.subplot(2,1,2)
+pl.scatter(yt,xt,c=phi2,s=70)
+pl.xlabel('coordinates x/y',fontsize=25)
+pl.ylabel('Feature value b',fontsize=20)
+pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1)
+pl.yticks(())
+pl.tight_layout()
+pl.show()
+
+
+#%% Structure matrices and across-features distance matrix
+C1=ot.dist(xs)
+C2=ot.dist(xt).T
+M=ot.dist(ys,yt)
+w1=ot.unif(C1.shape[0])
+w2=ot.unif(C2.shape[0])
+Got=ot.emd([],[],M)
+
+#%%
+cmap='Reds'
+pl.close(10)
+pl.figure(10,(5,5))
+fs=15
+l_x=[0,5,10,15]
+l_y=[0,5,10,15,20,25]
+gs = pl.GridSpec(5, 5)
+
+ax1=pl.subplot(gs[3:,:2])
+
+pl.imshow(C1,cmap=cmap,interpolation='nearest')
+pl.title("$C_1$",fontsize=fs)
+pl.xlabel("$k$",fontsize=fs)
+pl.ylabel("$i$",fontsize=fs)
+pl.xticks(l_x)
+pl.yticks(l_x)
+
+ax2=pl.subplot(gs[:3,2:])
+
+pl.imshow(C2,cmap=cmap,interpolation='nearest')
+pl.title("$C_2$",fontsize=fs)
+pl.ylabel("$l$",fontsize=fs)
+#pl.ylabel("$l$",fontsize=fs)
+pl.xticks(())
+pl.yticks(l_y)
+ax2.set_aspect('auto')
+
+ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1)
+pl.imshow(M,cmap=cmap,interpolation='nearest')
+pl.yticks(l_x)
+pl.xticks(l_y)
+pl.ylabel("$i$",fontsize=fs)
+pl.title("$M_{AB}$",fontsize=fs)
+pl.xlabel("$j$",fontsize=fs)
+pl.tight_layout()
+ax3.set_aspect('auto')
+pl.show()
+
+
+#%% Computing FGW and GW
+alpha=1e-3
+
+ot.tic()
+Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True)
+ot.toc()
+
+#%reload_ext WGW
+Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True)
+
+#%% visu OT matrix
+cmap='Blues'
+fs=15
+pl.figure(2,(13,5))
+pl.clf()
+pl.subplot(1,3,1)
+pl.imshow(Got,cmap=cmap,interpolation='nearest')
+#pl.xlabel("$y$",fontsize=fs)
+pl.ylabel("$i$",fontsize=fs)
+pl.xticks(())
+
+pl.title('Wasserstein ($M$ only)')
+
+pl.subplot(1,3,2)
+pl.imshow(Gg,cmap=cmap,interpolation='nearest')
+pl.title('Gromov ($C_1,C_2$ only)')
+pl.xticks(())
+pl.subplot(1,3,3)
+pl.imshow(Gwg,cmap=cmap,interpolation='nearest')
+pl.title('FGW ($M+C_1,C_2$)')
+
+pl.xlabel("$j$",fontsize=fs)
+pl.ylabel("$i$",fontsize=fs)
+
+pl.tight_layout()
+pl.show() \ No newline at end of file