diff options
author | Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> | 2023-06-12 12:01:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-12 12:01:48 +0200 |
commit | 9076f02903ba2fb9ea9fe704764a755cad8dcd63 (patch) | |
tree | b7fda84880c5dabd1c441a1655741493e0683342 /examples/gromov | |
parent | f0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (diff) |
[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455) upstream/latest
* add entropic fgw + fgw bary + srgw + srfgw with tests
* add examples for entropic srgw - srfgw solvers
* add PPA solvers for GW/FGW + complete previous commits
* update readme
* add tests
* add examples + tests + warning in entropic solvers + releases
* reduce testing runtimes for test_gromov
* fix conflicts
* optional marginals
* improve coverage
* gromov doc harmonization
* fix pep8
* complete optional marginal for entropic srfgw
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples/gromov')
-rw-r--r-- | examples/gromov/plot_entropic_semirelaxed_fgw.py | 304 | ||||
-rw-r--r-- | examples/gromov/plot_fgw.py | 32 | ||||
-rw-r--r-- | examples/gromov/plot_fgw_solvers.py | 288 | ||||
-rw-r--r-- | examples/gromov/plot_gromov.py | 112 |
4 files changed, 693 insertions, 43 deletions
diff --git a/examples/gromov/plot_entropic_semirelaxed_fgw.py b/examples/gromov/plot_entropic_semirelaxed_fgw.py new file mode 100644 index 0000000..642baea --- /dev/null +++ b/examples/gromov/plot_entropic_semirelaxed_fgw.py @@ -0,0 +1,304 @@ +# -*- coding: utf-8 -*- +""" +========================== +Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example +========================== + +This example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein +and the entropic semi-relaxed Fused Gromov-Wasserstein divergences. + +Entropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of +G2 at a minimal entropic-regularized (F)GW distance from G1. + +First, we generate two graphs following Stochastic Block Models, then show +how to compute their srGW matchings and illustrate them. These graphs are then +endowed with node features and we follow the same process with srFGW. + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" +International Conference on Learning Representations (ICLR), 2021. +""" + +# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. 
+# --------------------------------------------- + + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) + + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +# Add weights on the edges for visualization later on +weight_intra_G2 = 5 +weight_inter_G2 = 0.5 +weight_intra_G3 = 1. +weight_inter_G3 = 1.5 + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + +weightedG3 = networkx.Graph() +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +for node in G3.nodes(): + weightedG3.add_node(node) +for i, j in G3.edges(): + if part_G3[i] == part_G3[j]: + weightedG3.add_edge(i, j, weight=weight_intra_G3) + else: + weightedG3.add_edge(i, j, weight=weight_inter_G3) + +############################################################################# +# +# Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences +# --------------------------------------------- + +# 0) GW(C2, h2, C3, h3) for reference +OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) +gw = log['gw_dist'] + +# 1) srGW_e(C2, h2, C3) +OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein( + C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True) +srgw_23 = log_23['srgw_dist'] + +# 2) srGW_e(C3, h3, C2) + +OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein( + C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True) +srgw_32 = log_32['srgw_dist'] + +print('GW(C2, C3) = ', gw) +print('srGW_e(C2, h2, C3) = ', srgw_23) 
+print('srGW_e(C3, h3, C2) = ', srgw_32) + + +############################################################################# +# +# Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings +# --------------------------------------------- +# +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the entropic srGW matching. +# We adjust the intensity of links across domains proportionaly to the mass +# sent, adding a minimal intensity of 0.1 if mass sent is not zero. + + +def draw_graph(G, C, nodes_color_part, Gweights=None, + pos=None, edge_color='black', node_size=None, + shiftx=0, seed=0): + + if (pos is None): + pos = networkx.spring_layout(G, scale=1., seed=seed) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + else: + # We make more visible connections between activated nodes + n = len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, + width=width_edge, alpha=alpha_edge, + edge_color=edge_color) + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, + width=width_edge, alpha=0.1, + edge_color=edge_color) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=node_size, alpha=1, + node_color=node_color) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + 
node_size=nodes_size[node], alpha=1, + node_color=node_color) + return pos + + +def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, + p1, p2, T, pos1=None, pos2=None, + shiftx=4, switchx=False, node_size=70, + seed_G1=0, seed_G2=0): + starting_color = 0 + # get graphs partition and their coloring + part1 = part_G1.copy() + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + nodes_color_part1 = [] + for cluster in part1: + nodes_color_part1.append(unique_colors[cluster]) + + nodes_color_part2 = [] + # T: getting colors assignment from argmin of columns + for i in range(len(G2.nodes())): + j = np.argmax(T[:, i]) + nodes_color_part2.append(nodes_color_part1[j]) + pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, + pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) + pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, + node_size=node_size, shiftx=shiftx, seed=seed_G2) + for k1, v1 in pos1.items(): + max_Tk1 = np.max(T[k1, :]) + for k2, v2 in pos2.items(): + if (T[k1, k2] > 0): + pl.plot([pos1[k1][0], pos2[k2][0]], + [pos1[k1][1], pos2[k2][1]], + '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.), + color=nodes_color_part1[k1]) + return pos1, pos2 + + +node_size = 40 +fontsize = 10 +seed_G2 = 0 +seed_G3 = 4 + +pl.figure(1, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, 
shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() + +############################################################################# +# +# Add node features +# --------------------------------------------- + +# We add node features with given mean - by clusters +# and inversely proportional to clusters' intra-connectivity + +F2 = np.zeros((N2, 1)) +for i, c in enumerate(part_G2): + F2[i, 0] = np.random.normal(loc=c, scale=0.01) + +F3 = np.zeros((N3, 1)) +for i, c in enumerate(part_G3): + F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) + +############################################################################# +# +# Compute their semi-relaxed Fused Gromov-Wasserstein divergences +# --------------------------------------------- + +alpha = 0.5 +# Compute pairwise euclidean distance between node features +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) + +# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference + +OT, log = fused_gromov_wasserstein( + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) +fgw = log['fgw_dist'] + +# 1) srFGW_e(C2, F2, h2, C3, F3) +OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein( + M, C2, C3, h2, symmetric=True, epsilon=1., alpha=0.5, log=True, G0=None) +srfgw_23 = log_23['srfgw_dist'] + +# 2) srFGW(C3, F3, h3, C2, F2) + +OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein( + M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None) +srfgw_32 = log_32['srfgw_dist'] + +print('FGW(C2, F2, C3, F3) = ', fgw) +print(r'$srGW_e$(C2, F2, h2, C3, F3) = ', srfgw_23) +print(r'$srGW_e$(C3, F3, h3, C2, F2) = ', srfgw_32) + +############################################################################# +# +# Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings +# --------------------------------------------- +# +# We color nodes of the graph on the right - then project its node colors +# based on the 
optimal transport plan from the srFGW matching +# NB: colors refer to clusters - not to node features + +pl.figure(2, figsize=(8, 2.5)) +pl.clf() +pl.subplot(121) +pl.axis('off') +pl.axis +pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) + +hbar2 = OT_23.sum(axis=0) +pos1, pos2 = draw_transp_colored_srGW( + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) +pl.subplot(122) +pl.axis('off') +hbar3 = OT_32.sum(axis=0) +pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) +pos1, pos2 = draw_transp_colored_srGW( + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) +pl.tight_layout() + +pl.show() diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py index bf10de6..68ecb13 100644 --- a/examples/gromov/plot_fgw.py +++ b/examples/gromov/plot_fgw.py @@ -4,13 +4,13 @@ Plot Fused-Gromov-Wasserstein ============================== -This example illustrates the computation of FGW for 1D measures [18]. +This example first illustrates the computation of FGW for 1D measures estimated +using a Conditional Gradient solver [24]. -[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. 
- """ # Author: Titouan Vayer <titouan.vayer@irisa.fr> @@ -24,11 +24,13 @@ import numpy as np import ot from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein + ############################################################################## # Generate data # ------------- -#%% parameters +# parameters + # We create two 1D random measures n = 20 # number of points in the first distribution n2 = 30 # number of points in the second distribution @@ -53,10 +55,9 @@ q = ot.unif(n2) # Plot data # --------- -#%% plot the distributions +# plot the distributions -pl.close(10) -pl.figure(10, (7, 7)) +pl.figure(1, (7, 7)) pl.subplot(2, 1, 1) @@ -78,7 +79,7 @@ pl.show() # Create structure matrices and across-feature distance matrix # ------------------------------------------------------------ -#%% Structure matrices and across-features distance matrix +# Structure matrices and across-features distance matrix C1 = ot.dist(xs) C2 = ot.dist(xt) M = ot.dist(ys, yt) @@ -90,10 +91,9 @@ Got = ot.emd([], [], M) # Plot matrices # ------------- -#%% cmap = 'Reds' -pl.close(10) -pl.figure(10, (5, 5)) + +pl.figure(2, (5, 5)) fs = 15 l_x = [0, 5, 10, 15] l_y = [0, 5, 10, 15, 20, 25] @@ -113,7 +113,6 @@ ax2 = pl.subplot(gs[:3, 2:]) pl.imshow(C2, cmap=cmap, interpolation='nearest') pl.title("$C_2$", fontsize=fs) pl.ylabel("$l$", fontsize=fs) -#pl.ylabel("$l$",fontsize=fs) pl.xticks(()) pl.yticks(l_y) ax2.set_aspect('auto') @@ -133,28 +132,27 @@ pl.show() # Compute FGW/GW # -------------- -#%% Computing FGW and GW +# Computing FGW and GW alpha = 1e-3 ot.tic() Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) ot.toc() -#%reload_ext WGW +# reload_ext WGW Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) ############################################################################## # Visualize transport matrices # ---------------------------- -#%% visu OT matrix +# visu OT matrix 
cmap = 'Blues' fs = 15 -pl.figure(2, (13, 5)) +pl.figure(3, (13, 5)) pl.clf() pl.subplot(1, 3, 1) pl.imshow(Got, cmap=cmap, interpolation='nearest') -#pl.xlabel("$y$",fontsize=fs) pl.ylabel("$i$", fontsize=fs) pl.xticks(()) diff --git a/examples/gromov/plot_fgw_solvers.py b/examples/gromov/plot_fgw_solvers.py new file mode 100644 index 0000000..5f8a885 --- /dev/null +++ b/examples/gromov/plot_fgw_solvers.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +""" +============================== +Comparison of Fused Gromov-Wasserstein solvers +============================== + +This example illustrates the computation of FGW for attributed graphs +using 3 different solvers to estimate the distance based on Conditional +Gradient [24] or Sinkhorn projections [12, 51]. + +We generate two graphs following Stochastic Block Models further endowed with +node features and compute their FGW matchings. + +[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), +"Gromov-Wasserstein averaging of kernel and distance matrices". +International Conference on Machine Learning (ICML). + +[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. + +[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). +"Gromov-wasserstein learning for graph matching and node embedding". +In International Conference on Machine Learning (ICML), 2019. 
+""" + +# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. +# --------------------------------------------- +np.random.seed(0) + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + + +# We add node features with given mean - by clusters +# and inversely proportional to clusters' intra-connectivity + +F2 = np.zeros((N2, 1)) +for i, c in enumerate(part_G2): + F2[i, 0] = np.random.normal(loc=c, scale=0.01) + +F3 = np.zeros((N3, 1)) +for i, c in enumerate(part_G3): + F3[i, 0] = np.random.normal(loc=2. 
- c, scale=0.01) + +# Compute pairwise euclidean distance between node features +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +############################################################################# +# +# Compute their Fused Gromov-Wasserstein distances +# --------------------------------------------- + +alpha = 0.5 + + +# Conditional Gradient algorithm +fgw0, log0 = fused_gromov_wasserstein( + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True) + +# Proximal Point algorithm with Kullback-Leibler as proximal operator +fgw, log = entropic_fused_gromov_wasserstein( + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA', + log=True, verbose=True, warmstart=False, numItermax=10) + +# Projected Gradient algorithm with entropic regularization +fgwe, loge = entropic_fused_gromov_wasserstein( + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD', + log=True, verbose=True, warmstart=False, numItermax=10) + +print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist'])) +print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist'])) +print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist'])) + +# compute OT sparsity level +fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3) +fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3) +fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3) + +# Methods using Sinkhorn projections tend to produce feasibility errors on the +# marginal constraints + +err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3) +err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3) +erre = np.linalg.norm(fgwe.sum(1) - h2) + 
np.linalg.norm(fgwe.sum(0) - h3) + +############################################################################# +# +# Visualization of the Fused Gromov-Wasserstein matchings +# --------------------------------------------- +# +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the FGW matchings +# We adjust the intensity of links across domains proportionaly to the mass +# sent, adding a minimal intensity of 0.1 if mass sent is not zero. +# For each matching, all node sizes are proportionnal to their mass computed +# from marginals of the OT plan to illustrate potential feasibility errors. +# NB: colors refer to clusters - not to node features + +# Add weights on the edges for visualization later on +weight_intra_G2 = 5 +weight_inter_G2 = 0.5 +weight_intra_G3 = 1. +weight_inter_G3 = 1.5 + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + +weightedG3 = networkx.Graph() +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +for node in G3.nodes(): + weightedG3.add_node(node) +for i, j in G3.edges(): + if part_G3[i] == part_G3[j]: + weightedG3.add_edge(i, j, weight=weight_intra_G3) + else: + weightedG3.add_edge(i, j, weight=weight_inter_G3) + + +def draw_graph(G, C, nodes_color_part, Gweights=None, + pos=None, edge_color='black', node_size=None, + shiftx=0, seed=0): + + if (pos is None): + pos = networkx.spring_layout(G, scale=1., seed=seed) + + if shiftx != 0: + for k, v in pos.items(): + v[0] = v[0] + shiftx + + alpha_edge = 0.7 + width_edge = 1.8 + if Gweights is None: + networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) + else: + # We make more visible connections between activated nodes + n = 
len(Gweights) + edgelist_activated = [] + edgelist_deactivated = [] + for i in range(n): + for j in range(n): + if Gweights[i] * Gweights[j] * C[i, j] > 0: + edgelist_activated.append((i, j)) + elif C[i, j] > 0: + edgelist_deactivated.append((i, j)) + + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, + width=width_edge, alpha=alpha_edge, + edge_color=edge_color) + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, + width=width_edge, alpha=0.1, + edge_color=edge_color) + + if Gweights is None: + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=node_size, alpha=1, + node_color=node_color) + else: + scaled_Gweights = Gweights / (0.5 * Gweights.max()) + nodes_size = node_size * scaled_Gweights + for node, node_color in enumerate(nodes_color_part): + networkx.draw_networkx_nodes(G, pos, nodelist=[node], + node_size=nodes_size[node], alpha=1, + node_color=node_color) + return pos + + +def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T, + pos1=None, pos2=None, shiftx=4, switchx=False, + node_size=70, seed_G1=0, seed_G2=0): + starting_color = 0 + # get graphs partition and their coloring + part1 = part_G1.copy() + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] + nodes_color_part1 = [] + for cluster in part1: + nodes_color_part1.append(unique_colors[cluster]) + + nodes_color_part2 = [] + # T: getting colors assignment from argmin of columns + for i in range(len(G2.nodes())): + j = np.argmax(T[:, i]) + nodes_color_part2.append(nodes_color_part1[j]) + pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, + pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) + pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, + node_size=node_size, shiftx=shiftx, seed=seed_G2) + + for k1, v1 in pos1.items(): + max_Tk1 = np.max(T[k1, :]) + for k2, v2 in pos2.items(): + if (T[k1, k2] > 0): + pl.plot([pos1[k1][0], pos2[k2][0]], + 
[pos1[k1][1], pos2[k2][1]], + '-', lw=0.7, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.), + color=nodes_color_part1[k1]) + return pos1, pos2 + + +node_size = 40 +fontsize = 13 +seed_G2 = 0 +seed_G3 = 4 + +pl.figure(2, figsize=(12, 3.5)) +pl.clf() +pl.subplot(131) +pl.axis('off') +pl.axis +pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % ( + np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %', + np.round(err0, 4)), fontsize=fontsize) + +p0, q0 = fgw0.sum(1), fgw0.sum(0) # check marginals + +pos1, pos2 = draw_transp_colored_GW( + weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0, + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) + +pl.subplot(132) +pl.axis('off') + +p, q = fgw.sum(1), fgw.sum(0) # check marginals + +pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % ( + np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %', + np.round(err, 4)), fontsize=fontsize) + +pos1, pos2 = draw_transp_colored_GW( + weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw, + pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) + +pl.subplot(133) +pl.axis('off') + +pe, qe = fgwe.sum(1), fgwe.sum(0) # check marginals + +pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % ( + np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %', + np.round(erre, 4)), fontsize=fontsize) + +pos1, pos2 = draw_transp_colored_GW( + weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe, + pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0) + +pl.tight_layout() + +pl.show() diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index afb5bdc..252267f 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -5,13 +5,38 @@ Gromov-Wasserstein example ==========================
This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
+We first compare 3 solvers to estimate the distance based on
+Conditional Gradient [24] or Sinkhorn projections [12, 51].
+Then we compare 2 stochastic solvers to estimate the distance with a lower
+numerical cost [33].
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
+
+[33] Kerdoncuff T., Emonet R., Marc S. "Sampled Gromov Wasserstein",
+Machine Learning Journal (MLJ), 2021.
+
+[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019).
+"Gromov-wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML), 2019.
+
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+# Tanguy Kerdoncuff <tanguy.kerdoncuff@laposte.net>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 1
+
import scipy as sp
import numpy as np
import matplotlib.pylab as pl
@@ -36,7 +61,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
-
+np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
@@ -47,7 +72,7 @@ xt = np.random.randn(n_samples, 3).dot(P) + mu_t # --------------------------
-fig = pl.figure()
+fig = pl.figure(1)
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax2 = fig.add_subplot(122, projection='3d')
@@ -66,11 +91,15 @@ C2 = sp.spatial.distance.cdist(xt, xt) C1 /= C1.max()
C2 /= C2.max()
-pl.figure()
+pl.figure(2)
pl.subplot(121)
pl.imshow(C1)
+pl.title('C1')
+
pl.subplot(122)
pl.imshow(C2)
+pl.title('C2')
+
pl.show()
#############################################################################
@@ -81,32 +110,63 @@ pl.show() p = ot.unif(n_samples)
q = ot.unif(n_samples)
+# Conditional Gradient algorithm
gw0, log0 = ot.gromov.gromov_wasserstein(
C1, C2, p, q, 'square_loss', verbose=True, log=True)
+# Proximal Point algorithm with Kullback-Leibler as proximal operator
gw, log = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-
-
-print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
-print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-
-
-pl.figure(1, (10, 5))
-
-pl.subplot(1, 2, 1)
-pl.imshow(gw0, cmap='jet')
-pl.title('Gromov Wasserstein')
-
-pl.subplot(1, 2, 2)
-pl.imshow(gw, cmap='jet')
-pl.title('Entropic Gromov Wasserstein')
-
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA',
+ log=True, verbose=True)
+
+# Projected Gradient algorithm with entropic regularization
+gwe, loge = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD',
+ log=True, verbose=True)
+
+print('Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['gw_dist']))
+print('Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['gw_dist']))
+print('Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['gw_dist']))
+
+# compute OT sparsity level
+gw0_sparsity = 100 * (gw0 == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gw_sparsity = 100 * (gw == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gwe_sparsity = 100 * (gwe == 0.).astype(np.float64).sum() / (n_samples ** 2)
+
+# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# marginal constraints
+
+err0 = np.linalg.norm(gw0.sum(1) - p) + np.linalg.norm(gw0.sum(0) - q)
+err = np.linalg.norm(gw.sum(1) - p) + np.linalg.norm(gw.sum(0) - q)
+erre = np.linalg.norm(gwe.sum(1) - p) + np.linalg.norm(gwe.sum(0) - q)
+
+pl.figure(3, (10, 6))
+cmap = 'Blues'
+fontsize = 12
+pl.subplot(131)
+pl.imshow(gw0, cmap=cmap)
+pl.title('(CG algo) GW=%s \n \n OT sparsity=%s \n feasibility error=%s' % (
+ np.round(log0['gw_dist'], 4), str(np.round(gw0_sparsity, 2)) + ' %', np.round(err0, 4)),
+ fontsize=fontsize)
+
+pl.subplot(132)
+pl.imshow(gw, cmap=cmap)
+pl.title('(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+ np.round(log['gw_dist'], 4), str(np.round(gw_sparsity, 2)) + ' %', np.round(err, 4)),
+ fontsize=fontsize)
+
+pl.subplot(133)
+pl.imshow(gwe, cmap=cmap)
+pl.title('Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+ np.round(loge['gw_dist'], 4), str(np.round(gwe_sparsity, 2)) + ' %', np.round(erre, 4)),
+ fontsize=fontsize)
+
+pl.tight_layout()
pl.show()
#############################################################################
#
-# Compute GW with a scalable stochastic method with any loss function
+# Compute GW with scalable stochastic methods with any loss function
# ----------------------------------------------------------------------
@@ -126,14 +186,14 @@ print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated'])) print('Variance estimated: ' + str(slog['gw_dist_std']))
-pl.figure(1, (10, 5))
+pl.figure(4, (10, 5))
-pl.subplot(1, 2, 1)
-pl.imshow(pgw.toarray(), cmap='jet')
+pl.subplot(121)
+pl.imshow(pgw.toarray(), cmap=cmap)
pl.title('Pointwise Gromov Wasserstein')
-pl.subplot(1, 2, 2)
-pl.imshow(sgw, cmap='jet')
+pl.subplot(122)
+pl.imshow(sgw, cmap=cmap)
pl.title('Sampled Gromov Wasserstein')
pl.show()
|