summaryrefslogtreecommitdiff
path: root/examples/gromov/plot_gromov.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-11-09 17:06:47 +0100
committerGard Spreemann <gspr@nonempty.org>2021-11-09 17:06:47 +0100
commit3d10287c776e95427ace867b302cff02488694ca (patch)
tree5735b6434fe7a14b775d266e1a5d7720b56912e4 /examples/gromov/plot_gromov.py
parentcc703fc5e204a4b1c03fc29e59687e6b97aa7f67 (diff)
parent1a283cb0c77f79d6f36de7c01fa61dc8d9696bca (diff)
Merge branch 'dfsg/latest' into debian/sid
Diffstat (limited to 'examples/gromov/plot_gromov.py')
-rw-r--r--examples/gromov/plot_gromov.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..5a362cf 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()