diff options
author | Hicham Janati <hicham.janati100@gmail.com> | 2021-11-03 08:41:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-03 08:41:35 +0100 |
commit | e1b67c641da3b3e497db6811af2c200022b10302 (patch) | |
tree | 44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /examples | |
parent | 61340d526702616ff000d9e1cf71f52dd199a103 (diff) |
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic
* add debiased arg in tests
* add 1d and 2d examples of debiased barycenters
* fix doctest
* fix flake8
* pep8 + make func private + add convergence warnings
* remove rel paths + add rng + pylab to pyplot
* fix stopping criterion debiased
* pass alex
* change params with new API
* add logdomain barycenters + separate debiased API
* test new API
* fix jax read-only ?
* raise error for jax
* test catch jax error
* fix pytest catch error
* fix relative path
* fix flake8
* add warn arg everywhere
* fix ref number
* catch warnings in tests
* add contrib to readme + change ref number
* fix convolution example + gallery thumbnails
* increase coverage
* fix flake
Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples')
-rw-r--r-- | examples/barycenters/plot_barycenter_1D.py | 63 | ||||
-rw-r--r-- | examples/barycenters/plot_barycenter_lp_vs_entropic.py | 2 | ||||
-rw-r--r-- | examples/barycenters/plot_convolutional_barycenter.py | 53 | ||||
-rw-r--r-- | examples/barycenters/plot_debiased_barycenter.py | 131 | ||||
-rw-r--r-- | examples/domain-adaptation/plot_otda_color_images.py | 118 | ||||
-rw-r--r-- | examples/domain-adaptation/plot_otda_linear_mapping.py | 73 | ||||
-rw-r--r-- | examples/domain-adaptation/plot_otda_mapping_colors_images.py | 118 | ||||
-rwxr-xr-x | examples/gromov/plot_gromov_barycenter.py | 90 |
8 files changed, 390 insertions, 258 deletions
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 63dc460..2373e99 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138. # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 1 import numpy as np -import matplotlib.pylab as pl +import matplotlib.pyplot as plt import ot # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa @@ -51,18 +51,6 @@ M = ot.utils.dist0(n) M /= M.max() ############################################################################## -# Plot data -# --------- - -#%% plot the distributions - -pl.figure(1, figsize=(6.4, 3)) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') -pl.tight_layout() - -############################################################################## # Barycenter computation # ---------------------- @@ -78,24 +66,20 @@ bary_l2 = A.dot(weights) reg = 1e-3 bary_wass = ot.bregman.barycenter(A, M, reg, weights) -pl.figure(2) -pl.clf() -pl.subplot(2, 1, 1) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') +f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1) +ax1.plot(x, A, color="black") +ax1.set_title('Distributions') -pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Wasserstein') -pl.legend() -pl.title('Barycenters') -pl.tight_layout() +ax2.plot(x, bary_l2, 'r', label='l2') +ax2.plot(x, bary_wass, 'g', label='Wasserstein') +ax2.set_title('Barycenters') + +plt.legend() +plt.show() ############################################################################## # Barycentric interpolation # ------------------------- - #%% barycenter interpolation n_alpha = 11 @@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha)) B_wass = np.copy(B_l2) -for i in range(0, n_alpha): +for i in range(n_alpha): alpha = alpha_list[i] weights = np.array([1 - alpha, alpha]) B_l2[:, i] = A.dot(weights) B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation +plt.figure(2) -pl.figure(3) - -cmap = pl.cm.get_cmap('viridis') +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$') ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with l2') -pl.tight_layout() +plt.title('Barycenter interpolation with l2') +plt.tight_layout() -pl.figure(4) -cmap = pl.cm.get_cmap('viridis') +plt.figure(3) +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$') ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with Wasserstein') -pl.tight_layout() +plt.title('Barycenter interpolation with Wasserstein') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py index 57a6bac..6502f16 100644 --- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ================================================================================= -1D Wasserstein barycenter comparison between exact LP and entropic regularization +1D Wasserstein barycenter: exact LP vs entropic regularization ================================================================================= This example illustrates the computation of regularized Wasserstein Barycenter diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index cbcd4a1..3721f31 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -6,17 +6,18 @@ Convolutional Wasserstein Barycenter example ============================================ -This example is designed to illustrate how the Convolutional Wasserstein Barycenter -function of POT works. +This example is designed to illustrate how the Convolutional Wasserstein +Barycenter function of POT works. """ # Author: Nicolas Courty <ncourty@irisa.fr> # # License: MIT License - +import os +from pathlib import Path import numpy as np -import pylab as pl +import matplotlib.pyplot as plt import ot ############################################################################## @@ -25,22 +26,19 @@ import ot # # The four distributions are constructed from 4 simple images +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2] -f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2] -f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2] -f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2] +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] +f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -A = [] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) f3 = f3 / np.sum(f3) f4 = f4 / np.sum(f4) -A.append(f1) -A.append(f2) -A.append(f3) -A.append(f4) -A = np.array(A) +A = np.array([f1, f2, f3, f4]) nb_images = 5 @@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1)) # ---------------------------------------- # -pl.figure(figsize=(10, 10)) -pl.title('Convolutional Wasserstein Barycenters in POT') +fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7)) +plt.suptitle('Convolutional Wasserstein Barycenters in POT') cm = 'Blues' # regularization parameter reg = 0.004 for i in range(nb_images): for j in range(nb_images): - pl.subplot(nb_images, nb_images, i * nb_images + j + 1) tx = float(i) / (nb_images - 1) ty = float(j) / (nb_images - 1) @@ -74,19 +71,19 @@ for i in range(nb_images): weights = (1 - ty) * tmp1 + ty * tmp2 if i == 0 and j == 0: - pl.imshow(f1, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f1, cmap=cm) elif i == 0 and j == (nb_images - 1): - pl.imshow(f3, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f3, cmap=cm) elif i == (nb_images - 1) and j == 0: - pl.imshow(f2, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f2, cmap=cm) elif i == (nb_images - 1) and j == (nb_images - 1): - pl.imshow(f4, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f4, cmap=cm) else: # call to barycenter computation - pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) - pl.axis('off') -pl.show() + axes[i, j].imshow( + ot.bregman.convolutional_barycenter2d(A, reg, weights), + cmap=cm + ) + axes[i, j].axis('off') +plt.tight_layout() +plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py new file mode 100644 index 0000000..2a603dd --- /dev/null +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +""" +================================= +Debiased Sinkhorn barycenter demo +================================= + +This example illustrates the computation of the debiased Sinkhorn barycenter +as proposed in [37]_. + + +.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 +""" + +# Author: Hicham Janati <hicham.janati100@gmail.com> +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 3 + +import os +from pathlib import Path + +import numpy as np +import matplotlib.pyplot as plt + +import ot +from ot.bregman import (barycenter, barycenter_debiased, + convolutional_barycenter2d, + convolutional_barycenter2d_debiased) + +############################################################################## +# Debiased barycenter of 1D Gaussians +# ------------------------------------ + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +#%% barycenter computation + +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +epsilons = [5e-3, 1e-2, 5e-2] + + +bars = [barycenter(A, M, reg, weights) for reg in epsilons] +bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons] +labels = ["Sinkhorn barycenter", "Debiased barycenter"] +colors = ["indianred", "gold"] + +f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True, + figsize=(12, 4), num=1) +for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased): + ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3) + ax.plot(A[:, 1], color="k", ls="--", alpha=0.3) + for data, label, color in zip([bar, bar_debiased], labels, colors): + ax.plot(data, color=color, label=label, lw=2) + ax.set_title(r"$\varepsilon = %.3f$" % eps) +plt.legend() +plt.show() + + +############################################################################## +# Debiased barycenter of 2D images +# --------------------------------- +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] + +A = np.asarray([f1, f2]) + 1e-2 +A /= A.sum(axis=(1, 2))[:, None, None] + +############################################################################## +# Display the input images + +fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2) +for ax, img in zip(axes, A): + ax.imshow(img, cmap="Greys") + ax.axis("off") +fig.tight_layout() +plt.show() + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +bars_sinkhorn, bars_debiased = [], [] +epsilons = [5e-3, 7e-3, 1e-2] +for eps in epsilons: + bar = convolutional_barycenter2d(A, eps) + bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True) + bars_sinkhorn.append(bar) + bars_debiased.append(bar_debiased) + +titles = ["Sinkhorn", "Debiased"] +all_bars = [bars_sinkhorn, bars_debiased] +fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3) +for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)): + for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)): + ax.imshow(img, cmap="Greys") + if jj == 0: + ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13) + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + if ii == 0: + ax.set_ylabel(method, fontsize=15) +fig.tight_layout() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 6218b13..06dc8ab 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -19,12 +19,15 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882. # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path + import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -46,16 +49,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -65,39 +71,39 @@ Xt = X2[idx2, :] # Plot original image # ------------------- -pl.figure(1, figsize=(6.4, 3)) +plt.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') ############################################################################## # Scatter plot of colors # ---------------------- -pl.figure(2, figsize=(6.4, 3)) +plt.figure(2, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## @@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) # Plot new images # --------------- -pl.figure(3, figsize=(8, 4)) +plt.figure(3, figsize=(8, 4)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(2, 3, 2) -pl.imshow(I1t) -pl.axis('off') -pl.title('Image 1 Adapt') +plt.subplot(2, 3, 2) +plt.imshow(I1t) +plt.axis('off') +plt.title('Image 1 Adapt') -pl.subplot(2, 3, 3) -pl.imshow(I1te) -pl.axis('off') -pl.title('Image 1 Adapt (reg)') +plt.subplot(2, 3, 3) +plt.imshow(I1te) +plt.axis('off') +plt.title('Image 1 Adapt (reg)') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') -pl.subplot(2, 3, 5) -pl.imshow(I2t) -pl.axis('off') -pl.title('Image 2 Adapt') +plt.subplot(2, 3, 5) +plt.imshow(I2t) +plt.axis('off') +plt.title('Image 2 Adapt') -pl.subplot(2, 3, 6) -pl.imshow(I2te) -pl.axis('off') -pl.title('Image 2 Adapt (reg)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(I2te) +plt.axis('off') +plt.title('Image 2 Adapt (reg)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index be47510..a44096a 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -13,9 +13,11 @@ Linear OT mapping estimation # License: MIT License # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path import numpy as np -import pylab as pl +from matplotlib import pyplot as plt import ot ############################################################################## @@ -26,17 +28,19 @@ n = 1000 d = 2 sigma = .1 +rng = np.random.RandomState(42) + # source samples -angles = np.random.rand(n, 1) * 2 * np.pi +angles = rng.rand(n, 1) * 2 * np.pi xs = np.concatenate((np.sin(angles), np.cos(angles)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xs[:n // 2, 1] += 2 # target samples -anglet = np.random.rand(n, 1) * 2 * np.pi +anglet = rng.rand(n, 1) * 2 * np.pi xt = np.concatenate((np.sin(anglet), np.cos(anglet)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xt[:n // 2, 1] += 2 @@ -48,9 +52,9 @@ xt = xt.dot(A) + b # Plot data # --------- -pl.figure(1, (5, 5)) -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') +plt.figure(1, (5, 5)) +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') ############################################################################## @@ -66,13 +70,13 @@ xst = xs.dot(Ae) + be # Plot transported samples # ------------------------ -pl.figure(1, (5, 5)) -pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') -pl.plot(xst[:, 0], xst[:, 1], '+') +plt.figure(1, (5, 5)) +plt.clf() +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') +plt.plot(xst[:, 0], xst[:, 1], '+') -pl.show() +plt.show() ############################################################################## # Load image data @@ -94,8 +98,11 @@ def minmax(img): # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) @@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape)) # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 7)) +plt.figure(2, figsize=(10, 7)) -pl.subplot(2, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 2, 3) -pl.imshow(I1t) -pl.axis('off') -pl.title('Mapping Im. 1') +plt.subplot(2, 2, 3) +plt.imshow(I1t) +plt.axis('off') +plt.title('Mapping Im. 1') -pl.subplot(2, 2, 4) -pl.imshow(I2t) -pl.axis('off') -pl.title('Inverse mapping Im. 2') +plt.subplot(2, 2, 4) +plt.imshow(I2t) +plt.axis('off') +plt.title('Inverse mapping Im. 2') diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index 72010a6..dbece70 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. # License: MIT License # sphinx_gallery_thumbnail_number = 3 +import os +from pathlib import Path import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -48,17 +50,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape)) # Plot original images # -------------------- -pl.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.figure(1, figsize=(6.4, 3)) +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot pixel values distribution # ------------------------------ -pl.figure(2, figsize=(6.4, 5)) +plt.figure(2, figsize=(6.4, 5)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 5)) +plt.figure(2, figsize=(10, 5)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 3, 2) -pl.imshow(Image_emd) -pl.axis('off') -pl.title('EmdTransport') +plt.subplot(2, 3, 2) +plt.imshow(Image_emd) +plt.axis('off') +plt.title('EmdTransport') -pl.subplot(2, 3, 5) -pl.imshow(Image_sinkhorn) -pl.axis('off') -pl.title('SinkhornTransport') +plt.subplot(2, 3, 5) +plt.imshow(Image_sinkhorn) +plt.axis('off') +plt.title('SinkhornTransport') -pl.subplot(2, 3, 3) -pl.imshow(Image_mapping_linear) -pl.axis('off') -pl.title('MappingTransport (linear)') +plt.subplot(2, 3, 3) +plt.imshow(Image_mapping_linear) +plt.axis('off') +plt.title('MappingTransport (linear)') -pl.subplot(2, 3, 6) -pl.imshow(Image_mapping_gaussian) -pl.axis('off') -pl.title('MappingTransport (gaussian)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(Image_mapping_gaussian) +plt.axis('off') +plt.title('MappingTransport (gaussian)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index e2d88ba..7fe081f 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -13,11 +13,13 @@ computation in POT. #
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -89,17 +91,19 @@ def im2mat(img): return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)] npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
|