summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-28 16:50:00 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-28 16:50:00 +0200
commitb1b514f5d9de009e63bd407dfd9c0a0cf6128876 (patch)
tree1a0b9e972d09af049d10bfbab30f10f3b657487f /examples
parent549b95b5736b42f3fe74daf9805303a08b1ae01d (diff)
bary fgw
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_barycenter_fgw.py172
-rw-r--r--examples/plot_fgw.py1
2 files changed, 172 insertions, 1 deletions
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
new file mode 100644
index 0000000..f416629
--- /dev/null
+++ b/examples/plot_barycenter_fgw.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation of the barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+#%% load libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import math
+from scipy.sparse.csgraph import shortest_path
+import matplotlib.colors as mcol
+from matplotlib import cm
+from ot.gromov import fgw_barycenters
+#%% Graph functions
+
+def find_thresh(C,inf=0.5,sup=3,step=10):
+    """ Search for the threshold that best turns the structure matrix C into an adjacency matrix.
+
+    A line search tests "step" evenly spaced thresholds between "inf" and "sup".
+    For each candidate, C is thresholded into an adjacency matrix and the shortest-path
+    matrix of that adjacency matrix is compared with C. The selected threshold is the one
+    that minimizes this reconstruction error.
+
+    Parameters
+    ----------
+    C : ndarray, shape (n_nodes,n_nodes)
+        The structure matrix to threshold
+    inf : float
+        The beginning of the line search
+    sup : float
+        The end of the line search
+    step : integer
+        Number of thresholds tested
+
+    Returns
+    -------
+    thresh : float
+        The tested threshold with the smallest reconstruction error
+    dist : list of float
+        The reconstruction error of every tested threshold, in search order
+    """
+    def reconstruction_error(candidate):
+        adjacency=sp_to_adjency(C,0,candidate)
+        geodesics=shortest_path(adjacency,method='D')
+        # unreachable pairs come back as inf; cap them so that the norm stays finite
+        geodesics[np.isposinf(geodesics)]=100
+        return np.linalg.norm(geodesics-C)
+
+    candidates=np.linspace(inf,sup,step)
+    errors=[reconstruction_error(candidate) for candidate in candidates]
+    return candidates[np.argmin(errors)],errors
+
+def sp_to_adjency(C,threshinf=0.2,threshsup=1.8):
+    """ Thresholds the structure matrix in order to compute an adjacency matrix.
+
+    Two distinct nodes are considered connected (value 1) when their structure value lies
+    strictly between threshinf and threshsup. Every other entry, including the whole
+    diagonal, is set to 0, so the resulting graph never contains self-loops.
+    The input matrix is left untouched.
+
+    Parameters
+    ----------
+    C : ndarray, shape (n_nodes,n_nodes)
+        The structure matrix to threshold
+    threshinf : float
+        Values lower than or equal to this threshold are set to 0
+    threshsup : float
+        Values greater than or equal to this threshold are set to 0
+
+    Returns
+    -------
+    C : ndarray, shape (n_nodes,n_nodes)
+        The thresholded matrix, as floats. Each element is in {0,1}
+    """
+    C=np.asarray(C)
+    # An explicit mask replaces the former clip-then-binarise trick: clipping raised every
+    # value below a non-zero threshinf (diagonal included) up to threshinf, which then
+    # counted as "connected". For threshinf=0 both formulations give the same result.
+    connected=(C>threshinf)&(C<threshsup)
+    np.fill_diagonal(connected,False)
+    return connected.astype(float)
+
+def _noisy_sine_feature(i,N,mu,sigma,with_noise):
+    """ Feature of node i: sin(2*pi*i/N), plus a Gaussian perturbation when with_noise is set."""
+    # The noise is drawn even when it is not used, so that the random stream (and hence the
+    # structure noise drawn afterwards) does not depend on with_noise.
+    # Drawing without a size yields a scalar; float() on a 1-element array is deprecated.
+    noise=float(np.random.normal(mu,sigma))
+    value=math.sin(2*i*math.pi/N)
+    return value+noise if with_noise else value
+
+
+def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None):
+    """ Create a noisy circular graph
+
+    The graph is a cycle over the nodes 0..N (N+1 nodes). Each node i carries the feature
+    sin(2*pi*i/N) in its 'attr_name' attribute.
+
+    Parameters
+    ----------
+    N : integer
+        Index of the last node; the graph has N+1 nodes
+    mu : float
+        Mean of the Gaussian noise added to the features
+    sigma : float
+        Standard deviation of the Gaussian noise added to the features
+    with_noise : bool
+        If True the Gaussian noise is added to the features
+    structure_noise : bool
+        If True each node i<N gets, with probability 1/p, an extra edge to the node two
+        steps further along the cycle 0..N-1
+    p : integer
+        Inverse probability of an extra edge. Required when structure_noise is True
+
+    Returns
+    -------
+    g : networkx.Graph
+        The noisy circular graph
+
+    Raises
+    ------
+    ValueError
+        If structure_noise is True and p is not given
+    """
+    if structure_noise and p is None:
+        raise ValueError("p must be given when structure_noise is True")
+    g=nx.Graph()
+    g.add_nodes_from(range(N))
+    for i in range(N):
+        g.add_node(i,attr_name=_noisy_sine_feature(i,N,mu,sigma,with_noise))
+        g.add_edge(i,i+1)
+        if structure_noise and np.random.randint(0,p)==0:
+            # shortcut to the node two steps ahead, wrapping around at the end of the cycle
+            if i<=N-3:
+                g.add_edge(i,i+2)
+            elif i==N-2:
+                g.add_edge(i,0)
+            elif i==N-1:
+                g.add_edge(i,1)
+    # close the cycle and give the last node its feature
+    g.add_edge(N,0)
+    g.add_node(N,attr_name=_noisy_sine_feature(N,N,mu,sigma,with_noise))
+    return g
+
+def graph_colors(nx_graph,vmin=0,vmax=7):
+    """ Map the 'attr_name' attribute of every node to a color of the viridis colormap.
+
+    Parameters
+    ----------
+    nx_graph : networkx graph
+        Graph whose nodes all carry an 'attr_name' attribute
+    vmin : float
+        Attribute value mapped to the lower end of the colormap
+    vmax : float
+        Attribute value mapped to the upper end of the colormap
+
+    Returns
+    -------
+    colors : list
+        One RGBA color per node, in the order of nx_graph.nodes()
+    """
+    mapper=cm.ScalarMappable(norm=mcol.Normalize(vmin=vmin,vmax=vmax),cmap='viridis')
+    mapper.set_array([])
+    labels=nx.get_node_attributes(nx_graph,'attr_name')
+    rgba_by_node={node:mapper.to_rgba(label) for node,label in labels.items()}
+    return [rgba_by_node[node] for node in nx_graph.nodes()]
+
+#%% create dataset
+# We build a dataset of 9 noisy circular graphs of random size.
+# Noise is added to the structures by random extra connections and to the features by Gaussian noise.
+
+np.random.seed(30)  # fixed seed so that the example is reproducible
+X0=[build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3)
+    for _ in range(9)]
+
+#%% Plot dataset
+
+plt.figure(figsize=(8,10))
+# one panel per graph on a 3x3 grid; subplot indices start at 1
+for panel,g in enumerate(X0,start=1):
+    plt.subplot(3,3,panel)
+    pos=nx.kamada_kawai_layout(g)
+    nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100)
+plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20)
+plt.show()
+
+
+
+#%%
+# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+# Feature distances are the Euclidean distances
+Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0]
+# uniform weights over the nodes of each graph
+ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0]
+# one feature column per graph, in node order
+Ys=[np.array(list(nx.get_node_attributes(x,'attr_name').values())).reshape(-1,1) for x in X0]
+# every graph of the dataset has the same weight in the barycenter
+lambdas=np.ones(len(Ys))/len(Ys)
+sizebary=15 # we choose a barycenter with 15 nodes
+
+#%%
+
+# log=True is passed explicitly: the log is only returned as a third value when it is requested
+A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95,log=True)
+
+#%%
+# The barycenter structure C is thresholded into an adjacency matrix to obtain a graph
+thresh,_=find_thresh(C,sup=100,step=100)
+# from_numpy_array replaces from_numpy_matrix, which was removed in networkx 3.0
+bary=nx.from_numpy_array(sp_to_adjency(C,threshinf=0,threshsup=thresh))
+for i,feature in enumerate(A.ravel()):
+    bary.add_node(i,attr_name=float(feature))
+
+#%%
+pos = nx.kamada_kawai_layout(bary)
+nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False)
+plt.suptitle('Barycenter',fontsize=20)
+plt.show()
+
+
+
+
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
index 5c2d0e1..bfa7fb4 100644
--- a/examples/plot_fgw.py
+++ b/examples/plot_fgw.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
==============================