summaryrefslogtreecommitdiff
path: root/examples/gromov/plot_gromov_barycenter.py
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati100@gmail.com>2021-11-03 08:41:35 +0100
committerGitHub <noreply@github.com>2021-11-03 08:41:35 +0100
commite1b67c641da3b3e497db6811af2c200022b10302 (patch)
tree44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /examples/gromov/plot_gromov_barycenter.py
parent61340d526702616ff000d9e1cf71f52dd199a103 (diff)
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati <hicham.janati@inria.fr> Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples/gromov/plot_gromov_barycenter.py')
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py90
1 files changed, 47 insertions, 43 deletions
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index e2d88ba..7fe081f 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -13,11 +13,13 @@ computation in POT.
#
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -89,17 +91,19 @@ def im2mat(img):
return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')