summaryrefslogtreecommitdiff
path: root/examples/plot_gromov_barycenter.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/plot_gromov_barycenter.py')
-rwxr-xr-xexamples/plot_gromov_barycenter.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index c138031..52f4966 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -3,7 +3,7 @@
=====================================
Gromov-Wasserstein Barycenter example
=====================================
-This example is designed to show how to use the Gromov-Wassertsein distance
+This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
"""
@@ -34,8 +34,9 @@ that will be given by the output of the algorithm
def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
"""
- Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF
- multidimensional scaling (MDS) in specific dimensionned target space
+ Returns an interpolated point cloud following the dissimilarity matrix C
+ using SMACOF multidimensional scaling (MDS) in specific dimensionned
+ target space
Parameters
----------
@@ -51,7 +52,8 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
Returns
-------
npos : ndarray, shape (R, dim)
- Embedded coordinates of the interpolated point cloud (defined with one isometry)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
"""
rng = np.random.RandomState(seed=3)
@@ -88,10 +90,10 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
-square = spi.imread('../data/square.png').astype(np.float64)[:,:,2] / 256
-cross = spi.imread('../data/cross.png').astype(np.float64)[:,:,2] / 256
-triangle = spi.imread('../data/triangle.png').astype(np.float64)[:,:,2] / 256
-star = spi.imread('../data/star.png').astype(np.float64)[:,:,2] / 256
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
shapes = [square, cross, triangle, star]