diff options
Diffstat (limited to 'examples/domain-adaptation/plot_otda_color_images.py')
-rw-r--r-- | examples/domain-adaptation/plot_otda_color_images.py | 118 |
1 files changed, 62 insertions, 56 deletions
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 6218b13..06dc8ab 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -19,12 +19,15 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882. # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path + import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -46,16 +49,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -65,39 +71,39 @@ Xt = X2[idx2, :] # Plot original image # ------------------- -pl.figure(1, figsize=(6.4, 3)) +plt.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') ############################################################################## # Scatter plot of colors # ---------------------- -pl.figure(2, figsize=(6.4, 3)) +plt.figure(2, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## @@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) # Plot new images # --------------- -pl.figure(3, figsize=(8, 4)) +plt.figure(3, figsize=(8, 4)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(2, 3, 2) -pl.imshow(I1t) -pl.axis('off') -pl.title('Image 1 Adapt') +plt.subplot(2, 3, 2) +plt.imshow(I1t) +plt.axis('off') +plt.title('Image 1 Adapt') -pl.subplot(2, 3, 3) -pl.imshow(I1te) -pl.axis('off') -pl.title('Image 1 Adapt (reg)') +plt.subplot(2, 3, 3) +plt.imshow(I1te) +plt.axis('off') +plt.title('Image 1 Adapt (reg)') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') -pl.subplot(2, 3, 5) -pl.imshow(I2t) -pl.axis('off') -pl.title('Image 2 Adapt') +plt.subplot(2, 3, 5) +plt.imshow(I2t) +plt.axis('off') +plt.title('Image 2 Adapt') -pl.subplot(2, 3, 6) -pl.imshow(I2te) -pl.axis('off') -pl.title('Image 2 Adapt (reg)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(I2te) +plt.axis('off') +plt.title('Image 2 Adapt (reg)') +plt.tight_layout() -pl.show() +plt.show() |