diff options
author | Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> | 2023-06-12 12:01:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-12 12:01:48 +0200 |
commit | 9076f02903ba2fb9ea9fe704764a755cad8dcd63 (patch) | |
tree | b7fda84880c5dabd1c441a1655741493e0683342 /examples/gromov/plot_fgw.py | |
parent | f0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (diff) |
[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)upstream/latest
* add entropic fgw + fgw bary + srgw + srfgw with tests
* add exemples for entropic srgw - srfgw solvers
* add PPA solvers for GW/FGW + complete previous commits
* update readme
* add tests
* add examples + tests + warning in entropic solvers + releases
* reduce testing runtimes for test_gromov
* fix conflicts
* optional marginals
* improve coverage
* gromov doc harmonization
* fix pep8
* complete optional marginal for entropic srfgw
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples/gromov/plot_fgw.py')
-rw-r--r-- | examples/gromov/plot_fgw.py | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py index bf10de6..68ecb13 100644 --- a/examples/gromov/plot_fgw.py +++ b/examples/gromov/plot_fgw.py @@ -4,13 +4,13 @@ Plot Fused-Gromov-Wasserstein ============================== -This example illustrates the computation of FGW for 1D measures [18]. +This example first illustrates the computation of FGW for 1D measures estimated +using a Conditional Gradient solver [24]. -[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. - """ # Author: Titouan Vayer <titouan.vayer@irisa.fr> @@ -24,11 +24,13 @@ import numpy as np import ot from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein + ############################################################################## # Generate data # ------------- -#%% parameters +# parameters + # We create two 1D random measures n = 20 # number of points in the first distribution n2 = 30 # number of points in the second distribution @@ -53,10 +55,9 @@ q = ot.unif(n2) # Plot data # --------- -#%% plot the distributions +# plot the distributions -pl.close(10) -pl.figure(10, (7, 7)) +pl.figure(1, (7, 7)) pl.subplot(2, 1, 1) @@ -78,7 +79,7 @@ pl.show() # Create structure matrices and across-feature distance matrix # ------------------------------------------------------------ -#%% Structure matrices and across-features distance matrix +# Structure matrices and across-features distance matrix C1 = ot.dist(xs) C2 = ot.dist(xt) M = ot.dist(ys, yt) @@ -90,10 +91,9 @@ Got = ot.emd([], [], M) # Plot matrices # ------------- -#%% cmap = 'Reds' -pl.close(10) -pl.figure(10, (5, 5)) + +pl.figure(2, (5, 5)) fs = 15 l_x = [0, 5, 10, 15] l_y = [0, 5, 10, 15, 20, 25] @@ -113,7 +113,6 @@ ax2 = pl.subplot(gs[:3, 2:]) pl.imshow(C2, cmap=cmap, interpolation='nearest') pl.title("$C_2$", fontsize=fs) pl.ylabel("$l$", fontsize=fs) -#pl.ylabel("$l$",fontsize=fs) pl.xticks(()) pl.yticks(l_y) ax2.set_aspect('auto') @@ -133,28 +132,27 @@ pl.show() # Compute FGW/GW # -------------- -#%% Computing FGW and GW +# Computing FGW and GW alpha = 1e-3 ot.tic() Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) ot.toc() -#%reload_ext WGW +# reload_ext WGW Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) ############################################################################## # Visualize transport matrices # ---------------------------- -#%% visu OT matrix +# visu OT matrix cmap = 'Blues' fs = 15 -pl.figure(2, (13, 5)) +pl.figure(3, (13, 5)) pl.clf() pl.subplot(1, 3, 1) pl.imshow(Got, cmap=cmap, interpolation='nearest') -#pl.xlabel("$y$",fontsize=fs) pl.ylabel("$i$", fontsize=fs) pl.xticks(()) |