summaryrefslogtreecommitdiff
path: root/examples/plot_gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-16 13:53:34 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-16 13:53:34 +0100
commitf089a3cbc27c30ba9416ea1659c2fdbac1857146 (patch)
tree04905ed08da51190f15faa7aff67acb308ee583f /examples/plot_gromov.py
parent4a585de94109102c89bcd7ad43e35772e1027cd2 (diff)
better pep8 but not solved
Diffstat (limited to 'examples/plot_gromov.py')
-rw-r--r--examples/plot_gromov.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index 5f2d826..9188da9 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -20,7 +20,7 @@ from mpl_toolkits.mplot3d import Axes3D # noqa
import ot
-##############################################################################
+#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#
@@ -43,7 +43,7 @@ P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-##############################################################################
+#
# Plotting the distributions
# --------------------------
@@ -56,7 +56,7 @@ ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()
-##############################################################################
+#
# Compute distance kernels, normalize them and then display
# ---------------------------------------------------------
@@ -74,33 +74,32 @@ pl.subplot(122)
pl.imshow(C2)
pl.show()
-##############################################################################
+#
# Compute Gromov-Wasserstein plans and distance
# ---------------------------------------------
-#%%
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-gw0,log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True,log=True)
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
-gw,log= ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4,log=True,verbose=True)
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-pl.figure(1,(10,5))
+pl.figure(1, (10, 5))
-pl.subplot(1,2,1)
+pl.subplot(1, 2, 1)
pl.imshow(gw0, cmap='jet')
-pl.colorbar()
pl.title('Gromov Wasserstein')
-pl.subplot(1,2,2)
-pl.imshow(gw0, cmap='jet')
-pl.colorbar()
+pl.subplot(1, 2, 2)
+pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()