summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-09-12 18:07:17 +0200
committerNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-09-12 18:07:17 +0200
commit36bf599552ff15d1ca1c6b505507e65a333fa55e (patch)
tree8d508e9800bffef8c008c0872e46081c423d75dd /examples
parent8ea74ad41d660629a12f7d8d0d8816a23d385a92 (diff)
Corrections on Gromov
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_gromov.py5
-rwxr-xr-xexamples/plot_gromov_barycenter.py13
2 files changed, 8 insertions, 10 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index 92312ae..0f839a3 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -22,8 +22,9 @@ import ot
"""
Sample two Gaussian distributions (2D and 3D)
=============================================
-The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
-For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
+The Gromov-Wasserstein distance allows to compute distances with samples that
+do not belong to the same metric space. For demonstration purpose, we sample
+two Gaussian distributions in 2- and 3-dimensional spaces.
"""
n_samples = 30 # nb samples
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index 4f17117..c138031 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -48,13 +48,10 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
eps : float
relative tolerance w.r.t stress to declare converge
-
Returns
-------
npos : ndarray, shape (R, dim)
Embedded coordinates of the interpolated point cloud (defined with one isometry)
-
-
"""
rng = np.random.RandomState(seed=3)
@@ -91,12 +88,12 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
-square = spi.imread('../data/carre.png').astype(np.float64) / 256
-circle = spi.imread('../data/rond.png').astype(np.float64) / 256
-triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
-arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
+square = spi.imread('../data/square.png').astype(np.float64)[:,:,2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:,:,2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:,:,2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:,:,2] / 256
-shapes = [square, circle, triangle, arrow]
+shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]