From e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 17:05:38 +0200 Subject: code review1 --- examples/plot_fgw.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) (limited to 'examples/plot_fgw.py') diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py index ae3c487..43efc94 100644 --- a/examples/plot_fgw.py +++ b/examples/plot_fgw.py @@ -22,12 +22,16 @@ import numpy as np import ot from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein +############################################################################## +# Generate data +# --------- + #%% parameters # We create two 1D random measures -n = 20 -n2 = 30 -sig = 1 -sig2 = 0.1 +n = 20 # number of points in the first distribution +n2 = 30 # number of points in the second distribution +sig = 1 # std of first distribution +sig2 = 0.1 # std of second distribution np.random.seed(0) @@ -43,6 +47,10 @@ yt = yt[::-1, :] p = ot.unif(n) q = ot.unif(n2) +############################################################################## +# Plot data +# --------- + #%% plot the distributions pl.close(10) @@ -64,15 +72,22 @@ pl.yticks(()) pl.tight_layout() pl.show() +############################################################################## +# Create structure matrices and across-feature distance matrix +# --------- #%% Structure matrices and across-features distance matrix C1 = ot.dist(xs) -C2 = ot.dist(xt).T +C2 = ot.dist(xt) M = ot.dist(ys, yt) w1 = ot.unif(C1.shape[0]) w2 = ot.unif(C2.shape[0]) Got = ot.emd([], [], M) +############################################################################## +# Plot matrices +# --------- + #%% cmap = 'Reds' pl.close(10) @@ -112,6 +127,9 @@ pl.tight_layout() ax3.set_aspect('auto') pl.show() +############################################################################## +# Compute FGW/GW +# --------- #%% Computing FGW and GW alpha = 1e-3 @@ -123,6 +141,10 @@ ot.toc() #%reload_ext WGW Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) +############################################################################## +# Visualize transport matrices +# --------- + #%% visu OT matrix cmap = 'Blues' fs = 15 -- cgit v1.2.3