summaryrefslogtreecommitdiff
path: root/examples/plot_gromov.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-28 15:04:04 +0200
committerNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-28 15:04:04 +0200
commit0a68bf4e83ee9092c3f3878115fea894922d7d56 (patch)
treee01c94d682e2b02d4ae900c6e7d6bcd408490e98 /examples/plot_gromov.py
parent7ab9037f1e4a08439083d1bc5705be5ed2e9e10a (diff)
gromov:flake8 and other
Diffstat (limited to 'examples/plot_gromov.py')
-rw-r--r--examples/plot_gromov.py63
1 files changed, 27 insertions, 36 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index 11e5336..a33fde1 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -3,11 +3,8 @@
====================
Gromov-Wasserstein example
====================
-
-This example is designed to show how to use the Gromov-Wassertsein distance
-computation in POT.
-
-
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
@@ -20,43 +17,38 @@ import numpy as np
import ot
import matplotlib.pylab as pl
-from mpl_toolkits.mplot3d import Axes3D
-
"""
Sample two Gaussian distributions (2D and 3D)
====================
-
-The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. For
-demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
-
+The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
+For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
"""
-n=30 # nb samples
-mu_s=np.array([0,0])
-cov_s=np.array([[1,0],[0,1]])
+n = 30 # nb samples
-mu_t=np.array([4,4,4])
-cov_t=np.array([[1,0,0],[0,1,0],[0,0,1]])
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
-xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
-P=sp.linalg.sqrtm(cov_t)
-xt= np.random.randn(n,3).dot(P)+mu_t
-
+xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n, 3).dot(P) + mu_t
"""
Plotting the distributions
====================
"""
-fig=pl.figure()
-ax1=fig.add_subplot(121)
-ax1.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-ax2=fig.add_subplot(122,projection='3d')
-ax2.scatter(xt[:,0],xt[:,1],xt[:,2],color='r')
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()
@@ -65,11 +57,11 @@ Compute distance kernels, normalize them and then display
====================
"""
-C1=sp.spatial.distance.cdist(xs,xs)
-C2=sp.spatial.distance.cdist(xt,xt)
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
-C1/=C1.max()
-C2/=C2.max()
+C1 /= C1.max()
+C2 /= C2.max()
pl.figure()
pl.subplot(121)
@@ -83,16 +75,15 @@ Compute Gromov-Wasserstein plans and distance
====================
"""
-p=ot.unif(n)
-q=ot.unif(n)
+p = ot.unif(n)
+q = ot.unif(n)
-gw=ot.gromov_wasserstein(C1,C2,p,q,'square_loss',epsilon=5e-4)
-gw_dist=ot.gromov_wasserstein2(C1,C2,p,q,'square_loss',epsilon=5e-4)
+gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
-print('Gromov-Wasserstein distances between the distribution: '+str(gw_dist))
+print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
pl.figure()
-pl.imshow(gw,cmap='jet')
+pl.imshow(gw, cmap='jet')
pl.colorbar()
pl.show()
-