summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_gromov.py')
-rw-r--r--docs/source/auto_examples/plot_gromov.py39
1 files changed, 26 insertions, 13 deletions
diff --git a/docs/source/auto_examples/plot_gromov.py b/docs/source/auto_examples/plot_gromov.py
index d3f724c..5cd40f6 100644
--- a/docs/source/auto_examples/plot_gromov.py
+++ b/docs/source/auto_examples/plot_gromov.py
@@ -19,8 +19,8 @@ import matplotlib.pylab as pl
from mpl_toolkits.mplot3d import Axes3D # noqa
import ot
-
-##############################################################################
+#############################################################################
+#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#
@@ -42,8 +42,8 @@ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-
-##############################################################################
+#############################################################################
+#
# Plotting the distributions
# --------------------------
@@ -55,8 +55,8 @@ ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()
-
-##############################################################################
+#############################################################################
+#
# Compute distance kernels, normalize them and then display
# ---------------------------------------------------------
@@ -74,20 +74,33 @@ pl.subplot(122)
pl.imshow(C2)
pl.show()
-##############################################################################
+#############################################################################
+#
# Compute Gromov-Wasserstein plans and distance
# ---------------------------------------------
-
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
-gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
-print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-pl.figure()
+
+print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(gw0, cmap='jet')
+pl.title('Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
pl.imshow(gw, cmap='jet')
-pl.colorbar()
+pl.title('Entropic Gromov Wasserstein')
+
pl.show()