diff options
author | Hicham Janati <hicham.janati100@gmail.com> | 2021-11-03 08:41:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-03 08:41:35 +0100 |
commit | e1b67c641da3b3e497db6811af2c200022b10302 (patch) | |
tree | 44d42e1ae50d653bb07dd6ef9c1de14f71b21642 /examples/domain-adaptation/plot_otda_mapping_colors_images.py | |
parent | 61340d526702616ff000d9e1cf71f52dd199a103 (diff) |
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic
* add debiased arg in tests
* add 1d and 2d examples of debiased barycenters
* fix doctest
* fix flake8
* pep8 + make func private + add convergence warnings
* remove rel paths + add rng + pylab to pyplot
* fix stopping criterion debiased
* pass alex
* change params with new API
* add logdomain barycenters + separate debiased API
* test new API
* fix jax read-only ?
* raise error for jax
* test catch jax error
* fix pytest catch error
* fix relative path
* fix flake8
* add warn arg everywhere
* fix ref number
* catch warnings in tests
* add contrib to readme + change ref number
* fix convolution example + gallery thumbnails
* increase coverage
* fix flake
Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'examples/domain-adaptation/plot_otda_mapping_colors_images.py')
-rw-r--r-- | examples/domain-adaptation/plot_otda_mapping_colors_images.py | 118 |
1 files changed, 61 insertions, 57 deletions
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index 72010a6..dbece70 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. # License: MIT License # sphinx_gallery_thumbnail_number = 3 +import os +from pathlib import Path import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -48,17 +50,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape)) # Plot original images # -------------------- -pl.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.figure(1, figsize=(6.4, 3)) +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot pixel values distribution # ------------------------------ -pl.figure(2, figsize=(6.4, 5)) +plt.figure(2, figsize=(6.4, 5)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 5)) +plt.figure(2, figsize=(10, 5)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 3, 2) -pl.imshow(Image_emd) -pl.axis('off') -pl.title('EmdTransport') +plt.subplot(2, 3, 2) +plt.imshow(Image_emd) +plt.axis('off') +plt.title('EmdTransport') -pl.subplot(2, 3, 5) -pl.imshow(Image_sinkhorn) -pl.axis('off') -pl.title('SinkhornTransport') +plt.subplot(2, 3, 5) +plt.imshow(Image_sinkhorn) +plt.axis('off') +plt.title('SinkhornTransport') -pl.subplot(2, 3, 3) -pl.imshow(Image_mapping_linear) -pl.axis('off') -pl.title('MappingTransport (linear)') +plt.subplot(2, 3, 3) +plt.imshow(Image_mapping_linear) +plt.axis('off') +plt.title('MappingTransport (linear)') -pl.subplot(2, 3, 6) -pl.imshow(Image_mapping_gaussian) -pl.axis('off') -pl.title('MappingTransport (gaussian)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(Image_mapping_gaussian) +plt.axis('off') +plt.title('MappingTransport (gaussian)') +plt.tight_layout() -pl.show() +plt.show() |