summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-16 15:04:04 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-16 15:04:04 +0100
commitee19d423adc85a960c9a46e4f81c370196805dbf (patch)
tree1c0bc21a605d0097616c26cfdb846fc744ed43a0 /docs/source/auto_examples/plot_gromov.py
parentefdbf9e4fe9295fb1bec893e8aaa9102537cb7f5 (diff)
update notebooks
Diffstat (limited to 'docs/source/auto_examples/plot_gromov.py')
-rw-r--r--docs/source/auto_examples/plot_gromov.py32
1 files changed, 22 insertions, 10 deletions
diff --git a/docs/source/auto_examples/plot_gromov.py b/docs/source/auto_examples/plot_gromov.py
index d3f724c..9188da9 100644
--- a/docs/source/auto_examples/plot_gromov.py
+++ b/docs/source/auto_examples/plot_gromov.py
@@ -20,7 +20,7 @@ from mpl_toolkits.mplot3d import Axes3D # noqa
import ot
-##############################################################################
+#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#
@@ -43,7 +43,7 @@ P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-##############################################################################
+#
# Plotting the distributions
# --------------------------
@@ -56,7 +56,7 @@ ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()
-##############################################################################
+#
# Compute distance kernels, normalize them and then display
# ---------------------------------------------------------
@@ -74,20 +74,32 @@ pl.subplot(122)
pl.imshow(C2)
pl.show()
-##############################################################################
+#
# Compute Gromov-Wasserstein plans and distance
# ---------------------------------------------
-
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
-gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
-print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-pl.figure()
+
+print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(gw0, cmap='jet')
+pl.title('Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
pl.imshow(gw, cmap='jet')
-pl.colorbar()
+pl.title('Entropic Gromov Wasserstein')
+
pl.show()