summaryrefslogtreecommitdiff
path: root/examples/plot_fgw.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_fgw.py')
-rw-r--r--examples/plot_fgw.py32
1 files changed, 27 insertions, 5 deletions
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
index ae3c487..43efc94 100644
--- a/examples/plot_fgw.py
+++ b/examples/plot_fgw.py
@@ -22,12 +22,16 @@ import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+##############################################################################
+# Generate data
+# ---------
+
#%% parameters
# We create two 1D random measures
-n = 20
-n2 = 30
-sig = 1
-sig2 = 0.1
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
np.random.seed(0)
@@ -43,6 +47,10 @@ yt = yt[::-1, :]
p = ot.unif(n)
q = ot.unif(n2)
+##############################################################################
+# Plot data
+# ---------
+
#%% plot the distributions
pl.close(10)
@@ -64,15 +72,22 @@ pl.yticks(())
pl.tight_layout()
pl.show()
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# ---------
#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
-C2 = ot.dist(xt).T
+C2 = ot.dist(xt)
M = ot.dist(ys, yt)
w1 = ot.unif(C1.shape[0])
w2 = ot.unif(C2.shape[0])
Got = ot.emd([], [], M)
+##############################################################################
+# Plot matrices
+# ---------
+
#%%
cmap = 'Reds'
pl.close(10)
@@ -112,6 +127,9 @@ pl.tight_layout()
ax3.set_aspect('auto')
pl.show()
+##############################################################################
+# Compute FGW/GW
+# ---------
#%% Computing FGW and GW
alpha = 1e-3
@@ -123,6 +141,10 @@ ot.toc()
#%reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+##############################################################################
+# Visualize transport matrices
+# ---------
+
#%% visu OT matrix
cmap = 'Blues'
fs = 15