summaryrefslogtreecommitdiff
path: root/examples/gromov/plot_gromov.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/gromov/plot_gromov.py')
-rw-r--r--examples/gromov/plot_gromov.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..5a362cf 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()