diff options
author | Tanguy <tanguy.kerdoncuff@laposte.net> | 2021-09-17 18:36:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-17 18:36:33 +0200 |
commit | e0ba31ce39a7d9e65e50ea970a574b3db54e4207 (patch) | |
tree | 36c95fc33bd07be476c44f8b5ea65896cf1f0c9f /examples/gromov | |
parent | 96bf1a46e74d6985419e14222afb0b9241a7bb36 (diff) |
[MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275)
* Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein.
* Correct some lines in SaGroW and PoGroW to follow pep8 guide.
* Change nb_samples name. Use rdm state. Change symmetric check.
* Change names of len(p) and len(q) in SaGroW and PoGroW.
* Re-add some deleted lines in the comments of gromov.py
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples/gromov')
-rw-r--r-- | examples/gromov/plot_gromov.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index deb2f86..5a362cf 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet') pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
|