summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati100@gmail.com>2021-11-03 08:41:35 +0100
committerGitHub <noreply@github.com>2021-11-03 08:41:35 +0100
commite1b67c641da3b3e497db6811af2c200022b10302 (patch)
tree44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /examples
parent61340d526702616ff000d9e1cf71f52dd199a103 (diff)
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati <hicham.janati@inria.fr> Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples')
-rw-r--r--examples/barycenters/plot_barycenter_1D.py63
-rw-r--r--examples/barycenters/plot_barycenter_lp_vs_entropic.py2
-rw-r--r--examples/barycenters/plot_convolutional_barycenter.py53
-rw-r--r--examples/barycenters/plot_debiased_barycenter.py131
-rw-r--r--examples/domain-adaptation/plot_otda_color_images.py118
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py73
-rw-r--r--examples/domain-adaptation/plot_otda_mapping_colors_images.py118
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py90
8 files changed, 390 insertions, 258 deletions
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 63dc460..2373e99 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 1
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
@@ -51,18 +51,6 @@ M = ot.utils.dist0(n)
M /= M.max()
##############################################################################
-# Plot data
-# ---------
-
-#%% plot the distributions
-
-pl.figure(1, figsize=(6.4, 3))
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
-pl.tight_layout()
-
-##############################################################################
# Barycenter computation
# ----------------------
@@ -78,24 +66,20 @@ bary_l2 = A.dot(weights)
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
-pl.figure(2)
-pl.clf()
-pl.subplot(2, 1, 1)
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
+f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
+ax1.plot(x, A, color="black")
+ax1.set_title('Distributions')
-pl.subplot(2, 1, 2)
-pl.plot(x, bary_l2, 'r', label='l2')
-pl.plot(x, bary_wass, 'g', label='Wasserstein')
-pl.legend()
-pl.title('Barycenters')
-pl.tight_layout()
+ax2.plot(x, bary_l2, 'r', label='l2')
+ax2.plot(x, bary_wass, 'g', label='Wasserstein')
+ax2.set_title('Barycenters')
+
+plt.legend()
+plt.show()
##############################################################################
# Barycentric interpolation
# -------------------------
-
#%% barycenter interpolation
n_alpha = 11
@@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha))
B_wass = np.copy(B_l2)
-for i in range(0, n_alpha):
+for i in range(n_alpha):
alpha = alpha_list[i]
weights = np.array([1 - alpha, alpha])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
+plt.figure(2)
-pl.figure(3)
-
-cmap = pl.cm.get_cmap('viridis')
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with l2')
-pl.tight_layout()
+plt.title('Barycenter interpolation with l2')
+plt.tight_layout()
-pl.figure(4)
-cmap = pl.cm.get_cmap('viridis')
+plt.figure(3)
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with Wasserstein')
-pl.tight_layout()
+plt.title('Barycenter interpolation with Wasserstein')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
index 57a6bac..6502f16 100644
--- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py
+++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
-1D Wasserstein barycenter comparison between exact LP and entropic regularization
+1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================
This example illustrates the computation of regularized Wasserstein Barycenter
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index cbcd4a1..3721f31 100644
--- a/examples/barycenters/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -6,17 +6,18 @@
Convolutional Wasserstein Barycenter example
============================================
-This example is designed to illustrate how the Convolutional Wasserstein Barycenter
-function of POT works.
+This example is designed to illustrate how the Convolutional Wasserstein
+Barycenter function of POT works.
"""
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
-
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
import ot
##############################################################################
@@ -25,22 +26,19 @@ import ot
#
# The four distributions are constructed from 4 simple images
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
-f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
-f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
-f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
-f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
+f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
+f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
-A = []
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
f3 = f3 / np.sum(f3)
f4 = f4 / np.sum(f4)
-A.append(f1)
-A.append(f2)
-A.append(f3)
-A.append(f4)
-A = np.array(A)
+A = np.array([f1, f2, f3, f4])
nb_images = 5
@@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1))
# ----------------------------------------
#
-pl.figure(figsize=(10, 10))
-pl.title('Convolutional Wasserstein Barycenters in POT')
+fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
+plt.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
for i in range(nb_images):
for j in range(nb_images):
- pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
tx = float(i) / (nb_images - 1)
ty = float(j) / (nb_images - 1)
@@ -74,19 +71,19 @@ for i in range(nb_images):
weights = (1 - ty) * tmp1 + ty * tmp2
if i == 0 and j == 0:
- pl.imshow(f1, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f1, cmap=cm)
elif i == 0 and j == (nb_images - 1):
- pl.imshow(f3, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f3, cmap=cm)
elif i == (nb_images - 1) and j == 0:
- pl.imshow(f2, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f2, cmap=cm)
elif i == (nb_images - 1) and j == (nb_images - 1):
- pl.imshow(f4, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f4, cmap=cm)
else:
# call to barycenter computation
- pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
- pl.axis('off')
-pl.show()
+ axes[i, j].imshow(
+ ot.bregman.convolutional_barycenter2d(A, reg, weights),
+ cmap=cm
+ )
+ axes[i, j].axis('off')
+plt.tight_layout()
+plt.show()
diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py
new file mode 100644
index 0000000..2a603dd
--- /dev/null
+++ b/examples/barycenters/plot_debiased_barycenter.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Debiased Sinkhorn barycenter demo
+=================================
+
+This example illustrates the computation of the debiased Sinkhorn barycenter
+as proposed in [37]_.
+
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+ International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+"""
+
+# Author: Hicham Janati <hicham.janati100@gmail.com>
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
+import os
+from pathlib import Path
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import ot
+from ot.bregman import (barycenter, barycenter_debiased,
+ convolutional_barycenter2d,
+ convolutional_barycenter2d_debiased)
+
+##############################################################################
+# Debiased barycenter of 1D Gaussians
+# ------------------------------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+epsilons = [5e-3, 1e-2, 5e-2]
+
+
+bars = [barycenter(A, M, reg, weights) for reg in epsilons]
+bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
+labels = ["Sinkhorn barycenter", "Debiased barycenter"]
+colors = ["indianred", "gold"]
+
+f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
+ figsize=(12, 4), num=1)
+for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
+ ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
+ ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
+ for data, label, color in zip([bar, bar_debiased], labels, colors):
+ ax.plot(data, color=color, label=label, lw=2)
+ ax.set_title(r"$\varepsilon = %.3f$" % eps)
+plt.legend()
+plt.show()
+
+
+##############################################################################
+# Debiased barycenter of 2D images
+# ---------------------------------
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+
+A = np.asarray([f1, f2]) + 1e-2
+A /= A.sum(axis=(1, 2))[:, None, None]
+
+##############################################################################
+# Display the input images
+
+fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
+for ax, img in zip(axes, A):
+ ax.imshow(img, cmap="Greys")
+ ax.axis("off")
+fig.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+bars_sinkhorn, bars_debiased = [], []
+epsilons = [5e-3, 7e-3, 1e-2]
+for eps in epsilons:
+ bar = convolutional_barycenter2d(A, eps)
+ bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
+ bars_sinkhorn.append(bar)
+ bars_debiased.append(bar_debiased)
+
+titles = ["Sinkhorn", "Debiased"]
+all_bars = [bars_sinkhorn, bars_debiased]
+fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
+for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
+ for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
+ ax.imshow(img, cmap="Greys")
+ if jj == 0:
+ ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ if ii == 0:
+ ax.set_ylabel(method, fontsize=15)
+fig.tight_layout()
+plt.show()
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index 6218b13..06dc8ab 100644
--- a/examples/domain-adaptation/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -19,12 +19,15 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
+
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
def im2mat(img):
@@ -46,16 +49,19 @@ def minmax(img):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
nb = 500
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -65,39 +71,39 @@ Xt = X2[idx2, :]
# Plot original image
# -------------------
-pl.figure(1, figsize=(6.4, 3))
+plt.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
##############################################################################
# Scatter plot of colors
# ----------------------
-pl.figure(2, figsize=(6.4, 3))
+plt.figure(2, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
@@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
# Plot new images
# ---------------
-pl.figure(3, figsize=(8, 4))
+plt.figure(3, figsize=(8, 4))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(2, 3, 2)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Image 1 Adapt')
+plt.subplot(2, 3, 2)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Image 1 Adapt')
-pl.subplot(2, 3, 3)
-pl.imshow(I1te)
-pl.axis('off')
-pl.title('Image 1 Adapt (reg)')
+plt.subplot(2, 3, 3)
+plt.imshow(I1te)
+plt.axis('off')
+plt.title('Image 1 Adapt (reg)')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
-pl.subplot(2, 3, 5)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Image 2 Adapt')
+plt.subplot(2, 3, 5)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Image 2 Adapt')
-pl.subplot(2, 3, 6)
-pl.imshow(I2te)
-pl.axis('off')
-pl.title('Image 2 Adapt (reg)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(I2te)
+plt.axis('off')
+plt.title('Image 2 Adapt (reg)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index be47510..a44096a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,9 +13,11 @@ Linear OT mapping estimation
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+from matplotlib import pyplot as plt
import ot
##############################################################################
@@ -26,17 +28,19 @@ n = 1000
d = 2
sigma = .1
+rng = np.random.RandomState(42)
+
# source samples
-angles = np.random.rand(n, 1) * 2 * np.pi
+angles = rng.rand(n, 1) * 2 * np.pi
xs = np.concatenate((np.sin(angles), np.cos(angles)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xs[:n // 2, 1] += 2
# target samples
-anglet = np.random.rand(n, 1) * 2 * np.pi
+anglet = rng.rand(n, 1) * 2 * np.pi
xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xt[:n // 2, 1] += 2
@@ -48,9 +52,9 @@ xt = xt.dot(A) + b
# Plot data
# ---------
-pl.figure(1, (5, 5))
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
+plt.figure(1, (5, 5))
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
##############################################################################
@@ -66,13 +70,13 @@ xst = xs.dot(Ae) + be
# Plot transported samples
# ------------------------
-pl.figure(1, (5, 5))
-pl.clf()
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
-pl.plot(xst[:, 0], xst[:, 1], '+')
+plt.figure(1, (5, 5))
+plt.clf()
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
+plt.plot(xst[:, 0], xst[:, 1], '+')
-pl.show()
+plt.show()
##############################################################################
# Load image data
@@ -94,8 +98,11 @@ def minmax(img):
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
@@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape))
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 7))
+plt.figure(2, figsize=(10, 7))
-pl.subplot(2, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 2, 3)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Mapping Im. 1')
+plt.subplot(2, 2, 3)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Mapping Im. 1')
-pl.subplot(2, 2, 4)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Inverse mapping Im. 2')
+plt.subplot(2, 2, 4)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Inverse mapping Im. 2')
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index 72010a6..dbece70 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
+import os
+from pathlib import Path
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
def im2mat(img):
@@ -48,17 +50,19 @@ def minmax(img):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
nb = 500
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
# Plot original images
# --------------------
-pl.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.figure(1, figsize=(6.4, 3))
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot pixel values distribution
# ------------------------------
-pl.figure(2, figsize=(6.4, 5))
+plt.figure(2, figsize=(6.4, 5))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 5))
+plt.figure(2, figsize=(10, 5))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 3, 2)
-pl.imshow(Image_emd)
-pl.axis('off')
-pl.title('EmdTransport')
+plt.subplot(2, 3, 2)
+plt.imshow(Image_emd)
+plt.axis('off')
+plt.title('EmdTransport')
-pl.subplot(2, 3, 5)
-pl.imshow(Image_sinkhorn)
-pl.axis('off')
-pl.title('SinkhornTransport')
+plt.subplot(2, 3, 5)
+plt.imshow(Image_sinkhorn)
+plt.axis('off')
+plt.title('SinkhornTransport')
-pl.subplot(2, 3, 3)
-pl.imshow(Image_mapping_linear)
-pl.axis('off')
-pl.title('MappingTransport (linear)')
+plt.subplot(2, 3, 3)
+plt.imshow(Image_mapping_linear)
+plt.axis('off')
+plt.title('MappingTransport (linear)')
-pl.subplot(2, 3, 6)
-pl.imshow(Image_mapping_gaussian)
-pl.axis('off')
-pl.title('MappingTransport (gaussian)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(Image_mapping_gaussian)
+plt.axis('off')
+plt.title('MappingTransport (gaussian)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index e2d88ba..7fe081f 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -13,11 +13,13 @@ computation in POT.
#
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -89,17 +91,19 @@ def im2mat(img):
return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')