From b1b514f5d9de009e63bd407dfd9c0a0cf6128876 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 28 May 2019 16:50:00 +0200 Subject: bary fgw --- examples/plot_barycenter_fgw.py | 172 ++++++++++++++++++++++++++++++++++++++++ examples/plot_fgw.py | 1 - ot/gromov.py | 15 ++-- 3 files changed, 180 insertions(+), 8 deletions(-) create mode 100644 examples/plot_barycenter_fgw.py diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py new file mode 100644 index 0000000..f416629 --- /dev/null +++ b/examples/plot_barycenter_fgw.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +""" +================================= +Plot graphs' barycenter using FGW +================================= + +This example illustrates the computation barycenter of labeled graphs using FGW + +Requires networkx >=2 + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer +# +# License: MIT License + +#%% load libraries +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import fgw_barycenters +#%% Graph functions + +def find_thresh(C,inf=0.5,sup=3,step=10): + """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + and the original matrix. + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + inf : float + The beginning of the linesearch + sup : float + The end of the linesearch + step : integer + Number of thresholds tested + """ + dist=[] + search=np.linspace(inf,sup,step) + for thresh in search: + Cprime=sp_to_adjency(C,0,thresh) + SC=shortest_path(Cprime,method='D') + SC[SC==float('inf')]=100 + dist.append(np.linalg.norm(SC-C)) + return search[np.argmin(dist)],dist + +def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + threshinf : float + The minimum value of distance from which the new value is set to 1 + threshsup : float + The maximum value of distance from which the new value is set to 1 + Returns + ------- + C : ndarray, shape (n_nodes,n_nodes) + The threshold matrix. Each element is in {0,1} + """ + H=np.zeros_like(C) + np.fill_diagonal(H,np.diagonal(C)) + C=C-H + C=np.minimum(np.maximum(C,threshinf),threshsup) + C[C==threshsup]=0 + C[C!=0]=1 + + return C + +def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None): + """ Create a noisy circular graph + """ + g=nx.Graph() + g.add_nodes_from(list(range(N))) + for i in range(N): + noise=float(np.random.normal(mu,sigma,1)) + if with_noise: + g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise) + else: + g.add_node(i,attr_name=math.sin(2*i*math.pi/N)) + g.add_edge(i,i+1) + if structure_noise: + randomint=np.random.randint(0,p) + if randomint==0: + if i<=N-3: + g.add_edge(i,i+2) + if i==N-2: + g.add_edge(i,0) + if i==N-1: + g.add_edge(i,1) + g.add_edge(N,0) + noise=float(np.random.normal(mu,sigma,1)) + if with_noise: + g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise) + else: + g.add_node(N,attr_name=math.sin(2*N*math.pi/N)) + return g + +def graph_colors(nx_graph,vmin=0,vmax=7): + cnorm = mcol.Normalize(vmin=vmin,vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis') + cpick.set_array([]) + val_map = {} + for k,v in nx.get_node_attributes(nx_graph,'attr_name').items(): + val_map[k]=cpick.to_rgba(v) + colors=[] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + +#%% create dataset +# We build a dataset of noisy circular graphs. +# Noise is added on the structures by random connections and on the features by gaussian noise. + +np.random.seed(30) +X0=[] +for k in range(9): + X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3)) + +#%% Plot dataset + +plt.figure(figsize=(8,10)) +for i in range(len(X0)): + plt.subplot(3,3,i+1) + g=X0[i] + pos=nx.kamada_kawai_layout(g) + nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100) +plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20) +plt.show() + + + +#%% +# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +# Features distances are the euclidean distances +Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0] +ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0] +Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0] +lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel() +sizebary=15 # we choose a barycenter with 15 nodes + +#%% + +A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95) + +#%% +bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0])) +for i in range(len(A.ravel())): + bary.add_node(i,attr_name=float(A.ravel()[i])) + +#%% +pos = nx.kamada_kawai_layout(bary) +nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False) +plt.suptitle('Barycenter',fontsize=20) +plt.show() + + + + diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index 5c2d0e1..bfa7fb4 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ ============================== diff --git a/ot/gromov.py b/ot/gromov.py index 7491664..31bd657 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -883,8 +883,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, return C -def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,p=None,loss_fun='square_loss', - max_iter=100, tol=1e-9,verbose=False,log=True,init_C=None,init_X=None): +def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False, + p=None,loss_fun='square_loss',max_iter=100, tol=1e-9, + verbose=False,log=True,init_C=None,init_X=None): """ Compute the fgw barycenter as presented eq (5) in [3]. @@ -957,7 +958,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature X=np.zeros((N,d)) else: X = init_X - + T=[np.outer(p,q) for q in ps] # X is N,d @@ -981,7 +982,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if not fixed_features: Ys_temp=[y.T for y in Ys] - X=update_feature_matrix(lambdas,Ys_temp,T,p) + X=update_feature_matrix(lambdas,Ys_temp,T,p).T # X must be N,d # Ys must be ns,d @@ -1024,11 +1025,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature print('{:5d}|{:8e}|'.format(cpt, err_feature)) cpt += 1 - log_['T']=T # ce sont les matrices du barycentre de la target vers les Ys + log_['T']=T # from target to Ys log_['p']=p - log_['Ms']=Ms #Ms sont de tailles N,ns + log_['Ms']=Ms #Ms are N,ns - return X.T,C,log_ + return X,C,log_ def update_sructure_matrix(p, lambdas, T, Cs): -- cgit v1.2.3