From 9076f02903ba2fb9ea9fe704764a755cad8dcd63 Mon Sep 17 00:00:00 2001
From: Cédric Vincent-Cuaz
Date: Mon, 12 Jun 2023 12:01:48 +0200
Subject: [FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add entropic fgw + fgw bary + srgw + srfgw with tests

* add examples for entropic srgw - srfgw solvers

* add PPA solvers for GW/FGW + complete previous commits

* update readme

* add tests

* add examples + tests + warning in entropic solvers + releases

* reduce testing runtimes for test_gromov

* fix conflicts

* optional marginals

* improve coverage

* gromov doc harmonization

* fix pep8

* complete optional marginal for entropic srfgw

---------

Co-authored-by: Rémi Flamary
---
 examples/gromov/plot_gromov.py | 112 +++++++++++++++++++++++++++++++----------
 1 file changed, 86 insertions(+), 26 deletions(-)

(limited to 'examples/gromov/plot_gromov.py')

diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index afb5bdc..252267f 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -5,13 +5,38 @@ Gromov-Wasserstein example
 ==========================
 This example is designed to show how to use the Gromov-Wasserstein distance
 computation in POT.
+We first compare 3 solvers to estimate the distance: one based on
+Conditional Gradient [24] and two based on Sinkhorn projections [12, 51].
+Then we compare 2 stochastic solvers to estimate the distance with a lower
+numerical cost [33].
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas (2019),
+"Optimal Transport for structured data with application on graphs".
+International Conference on Machine Learning (ICML).
+
+[33] Kerdoncuff T., Emonet R., Sebban M. (2021),
+"Sampled Gromov Wasserstein". Machine Learning Journal (MLJ).
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019),
+"Gromov-Wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML).
+
 """

 # Author: Erwan Vautier
 #         Nicolas Courty
+#         Cédric Vincent-Cuaz
+#         Tanguy Kerdoncuff
 #
 # License: MIT License

+# sphinx_gallery_thumbnail_number = 1
+
 import scipy as sp
 import numpy as np
 import matplotlib.pylab as pl
@@ -36,7 +61,7 @@ cov_s = np.array([[1, 0], [0, 1]])

 mu_t = np.array([4, 4, 4])
 cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
-
+np.random.seed(0)
 xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
 P = sp.linalg.sqrtm(cov_t)
 xt = np.random.randn(n_samples, 3).dot(P) + mu_t
@@ -47,7 +72,7 @@ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
 # --------------------------


-fig = pl.figure()
+fig = pl.figure(1)
 ax1 = fig.add_subplot(121)
 ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
 ax2 = fig.add_subplot(122, projection='3d')
@@ -66,11 +91,15 @@ C2 = sp.spatial.distance.cdist(xt, xt)
 C1 /= C1.max()
 C2 /= C2.max()

-pl.figure()
+pl.figure(2)
 pl.subplot(121)
 pl.imshow(C1)
+pl.title('C1')
+
 pl.subplot(122)
 pl.imshow(C2)
+pl.title('C2')
+
 pl.show()

 #############################################################################
@@ -81,32 +110,63 @@ pl.show()
 p = ot.unif(n_samples)
 q = ot.unif(n_samples)

+# Conditional Gradient algorithm
 gw0, log0 = ot.gromov.gromov_wasserstein(
     C1, C2, p, q, 'square_loss', verbose=True, log=True)

+# Proximal Point algorithm with Kullback-Leibler as proximal operator
 gw, log = ot.gromov.entropic_gromov_wasserstein(
-    C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-
-
-print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
-print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-
-
-pl.figure(1, (10, 5))
-
-pl.subplot(1, 2, 1)
-pl.imshow(gw0, cmap='jet')
-pl.title('Gromov Wasserstein')
-
-pl.subplot(1, 2, 2)
-pl.imshow(gw, cmap='jet')
-pl.title('Entropic Gromov Wasserstein')
-
+    C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA',
+    log=True, verbose=True)
+
+# Projected Gradient algorithm with entropic regularization
+gwe, loge = ot.gromov.entropic_gromov_wasserstein(
+    C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD',
+    log=True, verbose=True)
+
+print('Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['gw_dist']))
+print('Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['gw_dist']))
+print('Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['gw_dist']))
+
+# compute OT sparsity level
+gw0_sparsity = 100 * (gw0 == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gw_sparsity = 100 * (gw == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gwe_sparsity = 100 * (gwe == 0.).astype(np.float64).sum() / (n_samples ** 2)
+
+# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# marginal constraints
+
+err0 = np.linalg.norm(gw0.sum(1) - p) + np.linalg.norm(gw0.sum(0) - q)
+err = np.linalg.norm(gw.sum(1) - p) + np.linalg.norm(gw.sum(0) - q)
+erre = np.linalg.norm(gwe.sum(1) - p) + np.linalg.norm(gwe.sum(0) - q)
+
+pl.figure(3, (10, 6))
+cmap = 'Blues'
+fontsize = 12
+pl.subplot(131)
+pl.imshow(gw0, cmap=cmap)
+pl.title('(CG algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+    np.round(log0['gw_dist'], 4), str(np.round(gw0_sparsity, 2)) + ' %', np.round(err0, 4)),
+    fontsize=fontsize)
+
+pl.subplot(132)
+pl.imshow(gw, cmap=cmap)
+pl.title('(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+    np.round(log['gw_dist'], 4), str(np.round(gw_sparsity, 2)) + ' %', np.round(err, 4)),
+    fontsize=fontsize)
+
+pl.subplot(133)
+pl.imshow(gwe, cmap=cmap)
+pl.title('(PGD algo) Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+    np.round(loge['gw_dist'], 4), str(np.round(gwe_sparsity, 2)) + ' %', np.round(erre, 4)),
+    fontsize=fontsize)
+
+pl.tight_layout()
 pl.show()

 #############################################################################
 #
-# Compute GW with a scalable stochastic method with any loss function
+# Compute GW with scalable stochastic methods with any loss function
 # ----------------------------------------------------------------------

@@ -126,14 +186,14 @@ print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
 print('Variance estimated: ' + str(slog['gw_dist_std']))

-pl.figure(1, (10, 5))
+pl.figure(4, (10, 5))

-pl.subplot(1, 2, 1)
-pl.imshow(pgw.toarray(), cmap='jet')
+pl.subplot(121)
+pl.imshow(pgw.toarray(), cmap=cmap)
 pl.title('Pointwise Gromov Wasserstein')

-pl.subplot(1, 2, 2)
-pl.imshow(sgw, cmap='jet')
+pl.subplot(122)
+pl.imshow(sgw, cmap=cmap)
 pl.title('Sampled Gromov Wasserstein')
 pl.show()
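
For readers who want to try the three solvers compared above without running the full gallery script, here is a minimal, self-contained sketch. It assumes a POT release that includes this patch (i.e. where entropic_gromov_wasserstein exposes the new `solver` argument); the sample sizes and the seed are illustrative only.

import numpy as np
from scipy.spatial.distance import cdist
import ot

rng = np.random.RandomState(42)
xs = rng.randn(20, 2)   # toy source samples in R^2
xt = rng.randn(20, 3)   # toy target samples in R^3

# intra-domain distance matrices, rescaled as in the example
C1 = cdist(xs, xs)
C2 = cdist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()

p = ot.unif(20)  # uniform weights on the source points
q = ot.unif(20)  # uniform weights on the target points

# exact GW with the Conditional Gradient solver
T_cg, log_cg = ot.gromov.gromov_wasserstein(
    C1, C2, p, q, 'square_loss', log=True)

# GW estimated with the Proximal Point solver (KL proximal operator)
T_ppa, log_ppa = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA', log=True)

# entropic GW with the Projected Gradient (Sinkhorn) solver
T_pgd, log_pgd = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD', log=True)

print('CG :', log_cg['gw_dist'])
print('PPA:', log_ppa['gw_dist'])
print('PGD:', log_pgd['gw_dist'])

Note that the PPA solver targets the unregularized GW cost (the entropy only enters through the proximal steps), so its estimate is directly comparable to the CG one, whereas PGD returns the value of the entropically regularized problem.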
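
The last section of the example relies on the stochastic solvers of [33]. A minimal sketch of that part follows, mirroring the calls made in the gallery script; the `loss` callable and the `epsilon`/`max_iter` values are illustrative, not tuned.

import numpy as np
from scipy.spatial.distance import cdist
import ot

rng = np.random.RandomState(0)
xs = rng.randn(20, 2)
xt = rng.randn(20, 3)
C1 = cdist(xs, xs)
C2 = cdist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
p = ot.unif(20)
q = ot.unif(20)


# unlike the exact solvers above, the stochastic solvers accept an
# arbitrary callable loss on pairs of distances
def loss(x, y):
    return np.abs(x - y)


# Pointwise GW [33]: updates the coupling from sampled entries of C1/C2;
# the returned plan is sparse (hence pgw.toarray() in the example above)
pgw, plog = ot.gromov.pointwise_gromov_wasserstein(
    C1, C2, p, q, loss, max_iter=100, log=True)

# Sampled GW [33]: estimates the GW gradient from sampled pairs of entries
sgw, slog = ot.gromov.sampled_gromov_wasserstein(
    C1, C2, p, q, loss, epsilon=0.1, max_iter=100, log=True)

print('Pointwise GW estimate:', plog['gw_dist_estimated'])
print('Sampled GW estimate:', slog['gw_dist_estimated'])
print('Sampled GW std:', slog['gw_dist_std'])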