From ee19d423adc85a960c9a46e4f81c370196805dbf Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 16 Feb 2018 15:04:04 +0100 Subject: update notebooks --- docs/source/auto_examples/plot_gromov.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) (limited to 'docs/source/auto_examples/plot_gromov.py') diff --git a/docs/source/auto_examples/plot_gromov.py b/docs/source/auto_examples/plot_gromov.py index d3f724c..9188da9 100644 --- a/docs/source/auto_examples/plot_gromov.py +++ b/docs/source/auto_examples/plot_gromov.py @@ -20,7 +20,7 @@ from mpl_toolkits.mplot3d import Axes3D # noqa import ot -############################################################################## +# # Sample two Gaussian distributions (2D and 3D) # --------------------------------------------- # @@ -43,7 +43,7 @@ P = sp.linalg.sqrtm(cov_t) xt = np.random.randn(n_samples, 3).dot(P) + mu_t -############################################################################## +# # Plotting the distributions # -------------------------- @@ -56,7 +56,7 @@ ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r') pl.show() -############################################################################## +# # Compute distance kernels, normalize them and then display # --------------------------------------------------------- @@ -74,20 +74,32 @@ pl.subplot(122) pl.imshow(C2) pl.show() -############################################################################## +# # Compute Gromov-Wasserstein plans and distance # --------------------------------------------- - p = ot.unif(n_samples) q = ot.unif(n_samples) -gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4) -gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4) +gw0, log0 = ot.gromov.gromov_wasserstein( + C1, C2, p, q, 'square_loss', verbose=True, log=True) -print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist)) +gw, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True) -pl.figure() + +print('Gromov-Wasserstein distances: ' + str(log0['gw_dist'])) +print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist'])) + + +pl.figure(1, (10, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(gw0, cmap='jet') +pl.title('Gromov Wasserstein') + +pl.subplot(1, 2, 2) pl.imshow(gw, cmap='jet') -pl.colorbar() +pl.title('Entropic Gromov Wasserstein') + pl.show() -- cgit v1.2.3