diff options
author | Nicolas Courty <Nico@pc-mna-08.univ-ubs.fr> | 2017-09-12 18:07:17 +0200 |
---|---|---|
committer | Nicolas Courty <Nico@pc-mna-08.univ-ubs.fr> | 2017-09-12 18:07:17 +0200 |
commit | 36bf599552ff15d1ca1c6b505507e65a333fa55e (patch) | |
tree | 8d508e9800bffef8c008c0872e46081c423d75dd /examples | |
parent | 8ea74ad41d660629a12f7d8d0d8816a23d385a92 (diff) |
Corrections on Gromov
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_gromov.py | 5 | ||||
-rwxr-xr-x | examples/plot_gromov_barycenter.py | 13 |
2 files changed, 8 insertions, 10 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py index 92312ae..0f839a3 100644 --- a/examples/plot_gromov.py +++ b/examples/plot_gromov.py @@ -22,8 +22,9 @@ import ot """
Sample two Gaussian distributions (2D and 3D)
=============================================
-The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
-For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
+The Gromov-Wasserstein distance allows to compute distances with samples that
+do not belong to the same metric space. For demonstration purpose, we sample
+two Gaussian distributions in 2- and 3-dimensional spaces.
"""
n_samples = 30 # nb samples
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py index 4f17117..c138031 100755 --- a/examples/plot_gromov_barycenter.py +++ b/examples/plot_gromov_barycenter.py @@ -48,13 +48,10 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9): eps : float
relative tolerance w.r.t stress to declare converge
-
Returns
-------
npos : ndarray, shape (R, dim)
Embedded coordinates of the interpolated point cloud (defined with one isometry)
-
-
"""
rng = np.random.RandomState(seed=3)
@@ -91,12 +88,12 @@ def im2mat(I): return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
-square = spi.imread('../data/carre.png').astype(np.float64) / 256
-circle = spi.imread('../data/rond.png').astype(np.float64) / 256
-triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
-arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
+square = spi.imread('../data/square.png').astype(np.float64)[:,:,2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:,:,2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:,:,2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:,:,2] / 256
-shapes = [square, circle, triangle, arrow]
+shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
|