summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/gromov/plot_entropic_semirelaxed_fgw.py304
-rw-r--r--examples/gromov/plot_fgw.py32
-rw-r--r--examples/gromov/plot_fgw_solvers.py288
-rw-r--r--examples/gromov/plot_gromov.py112
4 files changed, 693 insertions, 43 deletions
diff --git a/examples/gromov/plot_entropic_semirelaxed_fgw.py b/examples/gromov/plot_entropic_semirelaxed_fgw.py
new file mode 100644
index 0000000..642baea
--- /dev/null
+++ b/examples/gromov/plot_entropic_semirelaxed_fgw.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein
+and the entropic semi-relaxed Fused Gromov-Wasserstein divergences.
+
+Entropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of
+G2 at a minimal entropic-regularized (F)GW distance from G1.
+
+First, we generate two graphs following Stochastic Block Models, then show
+how to compute their srGW matchings and illustrate them. These graphs are then
+endowed with node features and we follow the same process with srFGW.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+International Conference on Learning Representations (ICLR), 2021.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
+# ---------------------------------------------
+
+
+N2 = 20 # 2 communities
+N3 = 30 # 3 communities
+p2 = [[1., 0.1],
+ [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+ [0.1, 0.95, 0.1],
+ [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+ weightedG2.add_node(node)
+for i, j in G2.edges():
+ if part_G2[i] == part_G2[j]:
+ weightedG2.add_edge(i, j, weight=weight_intra_G2)
+ else:
+ weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+ weightedG3.add_node(node)
+for i, j in G3.edges():
+ if part_G3[i] == part_G3[j]:
+ weightedG3.add_edge(i, j, weight=weight_intra_G3)
+ else:
+ weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+#############################################################################
+#
+# Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+# 0) GW(C2, h2, C3, h3) for reference
+OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
+gw = log['gw_dist']
+
+# 1) srGW_e(C2, h2, C3)
+OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein(
+ C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True)
+srgw_23 = log_23['srgw_dist']
+
+# 2) srGW_e(C3, h3, C2)
+
+OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein(
+ C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True)
+srgw_32 = log_32['srgw_dist']
+
+print('GW(C2, C3) = ', gw)
+print('srGW_e(C2, h2, C3) = ', srgw_23)
+print('srGW_e(C3, h3, C2) = ', srgw_32)
+
+
+#############################################################################
+#
+# Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the entropic srGW matching.
+# We adjust the intensity of links across domains proportionaly to the mass
+# sent, adding a minimal intensity of 0.1 if mass sent is not zero.
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+ pos=None, edge_color='black', node_size=None,
+ shiftx=0, seed=0):
+
+ if (pos is None):
+ pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+ if shiftx != 0:
+ for k, v in pos.items():
+ v[0] = v[0] + shiftx
+
+ alpha_edge = 0.7
+ width_edge = 1.8
+ if Gweights is None:
+ networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+ else:
+ # We make more visible connections between activated nodes
+ n = len(Gweights)
+ edgelist_activated = []
+ edgelist_deactivated = []
+ for i in range(n):
+ for j in range(n):
+ if Gweights[i] * Gweights[j] * C[i, j] > 0:
+ edgelist_activated.append((i, j))
+ elif C[i, j] > 0:
+ edgelist_deactivated.append((i, j))
+
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+ width=width_edge, alpha=alpha_edge,
+ edge_color=edge_color)
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+ width=width_edge, alpha=0.1,
+ edge_color=edge_color)
+
+ if Gweights is None:
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=node_size, alpha=1,
+ node_color=node_color)
+ else:
+ scaled_Gweights = Gweights / (0.5 * Gweights.max())
+ nodes_size = node_size * scaled_Gweights
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=nodes_size[node], alpha=1,
+ node_color=node_color)
+ return pos
+
+
+def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
+ p1, p2, T, pos1=None, pos2=None,
+ shiftx=4, switchx=False, node_size=70,
+ seed_G1=0, seed_G2=0):
+ starting_color = 0
+ # get graphs partition and their coloring
+ part1 = part_G1.copy()
+ unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+ nodes_color_part1 = []
+ for cluster in part1:
+ nodes_color_part1.append(unique_colors[cluster])
+
+ nodes_color_part2 = []
+ # T: getting colors assignment from argmin of columns
+ for i in range(len(G2.nodes())):
+ j = np.argmax(T[:, i])
+ nodes_color_part2.append(nodes_color_part1[j])
+ pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+ pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+ pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+ node_size=node_size, shiftx=shiftx, seed=seed_G2)
+ for k1, v1 in pos1.items():
+ max_Tk1 = np.max(T[k1, :])
+ for k2, v2 in pos2.items():
+ if (T[k1, k2] > 0):
+ pl.plot([pos1[k1][0], pos2[k2][0]],
+ [pos1[k1][1], pos2[k2][1]],
+ '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+ color=nodes_color_part1[k1])
+ return pos1, pos2
+
+
+node_size = 40
+fontsize = 10
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
+
+#############################################################################
+#
+# Add node features
+# ---------------------------------------------
+
+# We add node features with given mean - by clusters
+# and inversely proportional to clusters' intra-connectivity
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+ F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+ F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+#############################################################################
+#
+# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+alpha = 0.5
+# Compute pairwise euclidean distance between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
+
+# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference
+
+OT, log = fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
+fgw = log['fgw_dist']
+
+# 1) srFGW_e(C2, F2, h2, C3, F3)
+OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C2, C3, h2, symmetric=True, epsilon=1., alpha=0.5, log=True, G0=None)
+srfgw_23 = log_23['srfgw_dist']
+
+# 2) srFGW(C3, F3, h3, C2, F2)
+
+OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein(
+ M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None)
+srfgw_32 = log_32['srfgw_dist']
+
+print('FGW(C2, F2, C3, F3) = ', fgw)
+print(r'$srGW_e$(C2, F2, h2, C3, F3) = ', srfgw_23)
+print(r'$srGW_e$(C3, F3, h3, C2, F2) = ', srfgw_32)
+
+#############################################################################
+#
+# Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srFGW matching
+# NB: colors refer to clusters - not to node features
+
+pl.figure(2, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.axis
+pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index bf10de6..68ecb13 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -4,13 +4,13 @@
Plot Fused-Gromov-Wasserstein
==============================
-This example illustrates the computation of FGW for 1D measures [18].
+This example first illustrates the computation of FGW for 1D measures estimated
+using a Conditional Gradient solver [24].
-[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
-
"""
# Author: Titouan Vayer <titouan.vayer@irisa.fr>
@@ -24,11 +24,13 @@ import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
##############################################################################
# Generate data
# -------------
-#%% parameters
+# parameters
+
# We create two 1D random measures
n = 20 # number of points in the first distribution
n2 = 30 # number of points in the second distribution
@@ -53,10 +55,9 @@ q = ot.unif(n2)
# Plot data
# ---------
-#%% plot the distributions
+# plot the distributions
-pl.close(10)
-pl.figure(10, (7, 7))
+pl.figure(1, (7, 7))
pl.subplot(2, 1, 1)
@@ -78,7 +79,7 @@ pl.show()
# Create structure matrices and across-feature distance matrix
# ------------------------------------------------------------
-#%% Structure matrices and across-features distance matrix
+# Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
C2 = ot.dist(xt)
M = ot.dist(ys, yt)
@@ -90,10 +91,9 @@ Got = ot.emd([], [], M)
# Plot matrices
# -------------
-#%%
cmap = 'Reds'
-pl.close(10)
-pl.figure(10, (5, 5))
+
+pl.figure(2, (5, 5))
fs = 15
l_x = [0, 5, 10, 15]
l_y = [0, 5, 10, 15, 20, 25]
@@ -113,7 +113,6 @@ ax2 = pl.subplot(gs[:3, 2:])
pl.imshow(C2, cmap=cmap, interpolation='nearest')
pl.title("$C_2$", fontsize=fs)
pl.ylabel("$l$", fontsize=fs)
-#pl.ylabel("$l$",fontsize=fs)
pl.xticks(())
pl.yticks(l_y)
ax2.set_aspect('auto')
@@ -133,28 +132,27 @@ pl.show()
# Compute FGW/GW
# --------------
-#%% Computing FGW and GW
+# Computing FGW and GW
alpha = 1e-3
ot.tic()
Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
ot.toc()
-#%reload_ext WGW
+# reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
##############################################################################
# Visualize transport matrices
# ----------------------------
-#%% visu OT matrix
+# visu OT matrix
cmap = 'Blues'
fs = 15
-pl.figure(2, (13, 5))
+pl.figure(3, (13, 5))
pl.clf()
pl.subplot(1, 3, 1)
pl.imshow(Got, cmap=cmap, interpolation='nearest')
-#pl.xlabel("$y$",fontsize=fs)
pl.ylabel("$i$", fontsize=fs)
pl.xticks(())
diff --git a/examples/gromov/plot_fgw_solvers.py b/examples/gromov/plot_fgw_solvers.py
new file mode 100644
index 0000000..5f8a885
--- /dev/null
+++ b/examples/gromov/plot_fgw_solvers.py
@@ -0,0 +1,288 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+Comparison of Fused Gromov-Wasserstein solvers
+==============================
+
+This example illustrates the computation of FGW for attributed graphs
+using 3 different solvers to estimate the distance based on Conditional
+Gradient [24] or Sinkhorn projections [12, 51].
+
+We generate two graphs following Stochastic Block Models further endowed with
+node features and compute their FGW matchings.
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
+"Gromov-wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML), 2019.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
+# ---------------------------------------------
+np.random.seed(0)
+
+N2 = 20 # 2 communities
+N3 = 30 # 3 communities
+p2 = [[1., 0.1],
+ [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+ [0.1, 0.95, 0.1],
+ [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+
+# We add node features with given mean - by clusters
+# and inversely proportional to clusters' intra-connectivity
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+ F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+ F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+# Compute pairwise euclidean distance between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+#############################################################################
+#
+# Compute their Fused Gromov-Wasserstein distances
+# ---------------------------------------------
+
+alpha = 0.5
+
+
+# Conditional Gradient algorithm
+fgw0, log0 = fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True)
+
+# Proximal Point algorithm with Kullback-Leibler as proximal operator
+fgw, log = entropic_fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
+ log=True, verbose=True, warmstart=False, numItermax=10)
+
+# Projected Gradient algorithm with entropic regularization
+fgwe, loge = entropic_fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
+ log=True, verbose=True, warmstart=False, numItermax=10)
+
+print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist']))
+print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist']))
+print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist']))
+
+# compute OT sparsity level
+fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3)
+fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3)
+fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3)
+
+# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# marginal constraints
+
+err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3)
+err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3)
+erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3)
+
+#############################################################################
+#
+# Visualization of the Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the FGW matchings
+# We adjust the intensity of links across domains proportionaly to the mass
+# sent, adding a minimal intensity of 0.1 if mass sent is not zero.
+# For each matching, all node sizes are proportionnal to their mass computed
+# from marginals of the OT plan to illustrate potential feasibility errors.
+# NB: colors refer to clusters - not to node features
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+ weightedG2.add_node(node)
+for i, j in G2.edges():
+ if part_G2[i] == part_G2[j]:
+ weightedG2.add_edge(i, j, weight=weight_intra_G2)
+ else:
+ weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+ weightedG3.add_node(node)
+for i, j in G3.edges():
+ if part_G3[i] == part_G3[j]:
+ weightedG3.add_edge(i, j, weight=weight_intra_G3)
+ else:
+ weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+ pos=None, edge_color='black', node_size=None,
+ shiftx=0, seed=0):
+
+ if (pos is None):
+ pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+ if shiftx != 0:
+ for k, v in pos.items():
+ v[0] = v[0] + shiftx
+
+ alpha_edge = 0.7
+ width_edge = 1.8
+ if Gweights is None:
+ networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+ else:
+ # We make more visible connections between activated nodes
+ n = len(Gweights)
+ edgelist_activated = []
+ edgelist_deactivated = []
+ for i in range(n):
+ for j in range(n):
+ if Gweights[i] * Gweights[j] * C[i, j] > 0:
+ edgelist_activated.append((i, j))
+ elif C[i, j] > 0:
+ edgelist_deactivated.append((i, j))
+
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+ width=width_edge, alpha=alpha_edge,
+ edge_color=edge_color)
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+ width=width_edge, alpha=0.1,
+ edge_color=edge_color)
+
+ if Gweights is None:
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=node_size, alpha=1,
+ node_color=node_color)
+ else:
+ scaled_Gweights = Gweights / (0.5 * Gweights.max())
+ nodes_size = node_size * scaled_Gweights
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=nodes_size[node], alpha=1,
+ node_color=node_color)
+ return pos
+
+
+def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
+ pos1=None, pos2=None, shiftx=4, switchx=False,
+ node_size=70, seed_G1=0, seed_G2=0):
+ starting_color = 0
+ # get graphs partition and their coloring
+ part1 = part_G1.copy()
+ unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+ nodes_color_part1 = []
+ for cluster in part1:
+ nodes_color_part1.append(unique_colors[cluster])
+
+ nodes_color_part2 = []
+ # T: getting colors assignment from argmin of columns
+ for i in range(len(G2.nodes())):
+ j = np.argmax(T[:, i])
+ nodes_color_part2.append(nodes_color_part1[j])
+ pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+ pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+ pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+ node_size=node_size, shiftx=shiftx, seed=seed_G2)
+
+ for k1, v1 in pos1.items():
+ max_Tk1 = np.max(T[k1, :])
+ for k2, v2 in pos2.items():
+ if (T[k1, k2] > 0):
+ pl.plot([pos1[k1][0], pos2[k2][0]],
+ [pos1[k1][1], pos2[k2][1]],
+ '-', lw=0.7, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+ color=nodes_color_part1[k1])
+ return pos1, pos2
+
+
+node_size = 40
+fontsize = 13
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(2, figsize=(12, 3.5))
+pl.clf()
+pl.subplot(131)
+pl.axis('off')
+pl.axis
+pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % (
+ np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %',
+ np.round(err0, 4)), fontsize=fontsize)
+
+p0, q0 = fgw0.sum(1), fgw0.sum(0) # check marginals
+
+pos1, pos2 = draw_transp_colored_GW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+
+pl.subplot(132)
+pl.axis('off')
+
+p, q = fgw.sum(1), fgw.sum(0) # check marginals
+
+pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
+ np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %',
+ np.round(err, 4)), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw,
+ pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+pl.subplot(133)
+pl.axis('off')
+
+pe, qe = fgwe.sum(1), fgwe.sum(0) # check marginals
+
+pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
+ np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %',
+ np.round(erre, 4)), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe,
+ pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index afb5bdc..252267f 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -5,13 +5,38 @@ Gromov-Wasserstein example
==========================
This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
+We first compare 3 solvers to estimate the distance based on
+Conditional Gradient [24] or Sinkhorn projections [12, 51].
+Then we compare 2 stochastic solvers to estimate the distance with a lower
+numerical cost [33].
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
+
+[33] Kerdoncuff T., Emonet R., Marc S. "Sampled Gromov Wasserstein",
+Machine Learning Journal (MJL), 2021.
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
+"Gromov-wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML), 2019.
+
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+# Tanguy Kerdoncuff <tanguy.kerdoncuff@laposte.net>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 1
+
import scipy as sp
import numpy as np
import matplotlib.pylab as pl
@@ -36,7 +61,7 @@ cov_s = np.array([[1, 0], [0, 1]])
mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
-
+np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
@@ -47,7 +72,7 @@ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
# --------------------------
-fig = pl.figure()
+fig = pl.figure(1)
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax2 = fig.add_subplot(122, projection='3d')
@@ -66,11 +91,15 @@ C2 = sp.spatial.distance.cdist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
-pl.figure()
+pl.figure(2)
pl.subplot(121)
pl.imshow(C1)
+pl.title('C1')
+
pl.subplot(122)
pl.imshow(C2)
+pl.title('C2')
+
pl.show()
#############################################################################
@@ -81,32 +110,63 @@ pl.show()
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+# Conditional Gradient algorithm
gw0, log0 = ot.gromov.gromov_wasserstein(
C1, C2, p, q, 'square_loss', verbose=True, log=True)
+# Proximal Point algorithm with Kullback-Leibler as proximal operator
gw, log = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-
-
-print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
-print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-
-
-pl.figure(1, (10, 5))
-
-pl.subplot(1, 2, 1)
-pl.imshow(gw0, cmap='jet')
-pl.title('Gromov Wasserstein')
-
-pl.subplot(1, 2, 2)
-pl.imshow(gw, cmap='jet')
-pl.title('Entropic Gromov Wasserstein')
-
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA',
+ log=True, verbose=True)
+
+# Projected Gradient algorithm with entropic regularization
+gwe, loge = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD',
+ log=True, verbose=True)
+
+print('Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['gw_dist']))
+print('Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['gw_dist']))
+print('Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['gw_dist']))
+
+# compute OT sparsity level
+gw0_sparsity = 100 * (gw0 == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gw_sparsity = 100 * (gw == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gwe_sparsity = 100 * (gwe == 0.).astype(np.float64).sum() / (n_samples ** 2)
+
+# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# marginal constraints
+
+err0 = np.linalg.norm(gw0.sum(1) - p) + np.linalg.norm(gw0.sum(0) - q)
+err = np.linalg.norm(gw.sum(1) - p) + np.linalg.norm(gw.sum(0) - q)
+erre = np.linalg.norm(gwe.sum(1) - p) + np.linalg.norm(gwe.sum(0) - q)
+
+pl.figure(3, (10, 6))
+cmap = 'Blues'
+fontsize = 12
+pl.subplot(131)
+pl.imshow(gw0, cmap=cmap)
+pl.title('(CG algo) GW=%s \n \n OT sparsity=%s \n feasibility error=%s' % (
+ np.round(log0['gw_dist'], 4), str(np.round(gw0_sparsity, 2)) + ' %', np.round(np.round(err0, 4))),
+ fontsize=fontsize)
+
+pl.subplot(132)
+pl.imshow(gw, cmap=cmap)
+pl.title('(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+ np.round(log['gw_dist'], 4), str(np.round(gw_sparsity, 2)) + ' %', np.round(err, 4)),
+ fontsize=fontsize)
+
+pl.subplot(133)
+pl.imshow(gwe, cmap=cmap)
+pl.title('Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+ np.round(loge['gw_dist'], 4), str(np.round(gwe_sparsity, 2)) + ' %', np.round(erre, 4)),
+ fontsize=fontsize)
+
+pl.tight_layout()
pl.show()
#############################################################################
#
-# Compute GW with a scalable stochastic method with any loss function
+# Compute GW with scalable stochastic methods with any loss function
# ----------------------------------------------------------------------
@@ -126,14 +186,14 @@ print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
print('Variance estimated: ' + str(slog['gw_dist_std']))
-pl.figure(1, (10, 5))
+pl.figure(4, (10, 5))
-pl.subplot(1, 2, 1)
-pl.imshow(pgw.toarray(), cmap='jet')
+pl.subplot(121)
+pl.imshow(pgw.toarray(), cmap=cmap)
pl.title('Pointwise Gromov Wasserstein')
-pl.subplot(1, 2, 2)
-pl.imshow(sgw, cmap='jet')
+pl.subplot(122)
+pl.imshow(sgw, cmap=cmap)
pl.title('Sampled Gromov Wasserstein')
pl.show()