summaryrefslogtreecommitdiff
path: root/examples/gromov/plot_fgw.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gromov/plot_fgw.py')
-rw-r--r--examples/gromov/plot_fgw.py32
1 files changed, 15 insertions, 17 deletions
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index bf10de6..68ecb13 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -4,13 +4,13 @@
Plot Fused-Gromov-Wasserstein
==============================
-This example illustrates the computation of FGW for 1D measures [18].
+This example first illustrates the computation of FGW for 1D measures estimated
+using a Conditional Gradient solver [24].
-[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
-
"""
# Author: Titouan Vayer <titouan.vayer@irisa.fr>
@@ -24,11 +24,13 @@ import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
##############################################################################
# Generate data
# -------------
-#%% parameters
+# parameters
+
# We create two 1D random measures
n = 20 # number of points in the first distribution
n2 = 30 # number of points in the second distribution
@@ -53,10 +55,9 @@ q = ot.unif(n2)
# Plot data
# ---------
-#%% plot the distributions
+# plot the distributions
-pl.close(10)
-pl.figure(10, (7, 7))
+pl.figure(1, (7, 7))
pl.subplot(2, 1, 1)
@@ -78,7 +79,7 @@ pl.show()
# Create structure matrices and across-feature distance matrix
# ------------------------------------------------------------
-#%% Structure matrices and across-features distance matrix
+# Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
C2 = ot.dist(xt)
M = ot.dist(ys, yt)
@@ -90,10 +91,9 @@ Got = ot.emd([], [], M)
# Plot matrices
# -------------
-#%%
cmap = 'Reds'
-pl.close(10)
-pl.figure(10, (5, 5))
+
+pl.figure(2, (5, 5))
fs = 15
l_x = [0, 5, 10, 15]
l_y = [0, 5, 10, 15, 20, 25]
@@ -113,7 +113,6 @@ ax2 = pl.subplot(gs[:3, 2:])
pl.imshow(C2, cmap=cmap, interpolation='nearest')
pl.title("$C_2$", fontsize=fs)
pl.ylabel("$l$", fontsize=fs)
-#pl.ylabel("$l$",fontsize=fs)
pl.xticks(())
pl.yticks(l_y)
ax2.set_aspect('auto')
@@ -133,28 +132,27 @@ pl.show()
# Compute FGW/GW
# --------------
-#%% Computing FGW and GW
+# Computing FGW and GW
alpha = 1e-3
ot.tic()
Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
ot.toc()
-#%reload_ext WGW
+# reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
##############################################################################
# Visualize transport matrices
# ----------------------------
-#%% visu OT matrix
+# visu OT matrix
cmap = 'Blues'
fs = 15
-pl.figure(2, (13, 5))
+pl.figure(3, (13, 5))
pl.clf()
pl.subplot(1, 3, 1)
pl.imshow(Got, cmap=cmap, interpolation='nearest')
-#pl.xlabel("$y$",fontsize=fs)
pl.ylabel("$i$", fontsize=fs)
pl.xticks(())