summaryrefslogtreecommitdiff
path: root/examples/domain-adaptation/plot_otda_mapping_colors_images.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/domain-adaptation/plot_otda_mapping_colors_images.py')
-rw-r--r--examples/domain-adaptation/plot_otda_mapping_colors_images.py118
1 files changed, 61 insertions, 57 deletions
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index 72010a6..dbece70 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
+import os
+from pathlib import Path
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
def im2mat(img):
@@ -48,17 +50,19 @@ def minmax(img):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
nb = 500
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
# Plot original images
# --------------------
-pl.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.figure(1, figsize=(6.4, 3))
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot pixel values distribution
# ------------------------------
-pl.figure(2, figsize=(6.4, 5))
+plt.figure(2, figsize=(6.4, 5))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 5))
+plt.figure(2, figsize=(10, 5))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 3, 2)
-pl.imshow(Image_emd)
-pl.axis('off')
-pl.title('EmdTransport')
+plt.subplot(2, 3, 2)
+plt.imshow(Image_emd)
+plt.axis('off')
+plt.title('EmdTransport')
-pl.subplot(2, 3, 5)
-pl.imshow(Image_sinkhorn)
-pl.axis('off')
-pl.title('SinkhornTransport')
+plt.subplot(2, 3, 5)
+plt.imshow(Image_sinkhorn)
+plt.axis('off')
+plt.title('SinkhornTransport')
-pl.subplot(2, 3, 3)
-pl.imshow(Image_mapping_linear)
-pl.axis('off')
-pl.title('MappingTransport (linear)')
+plt.subplot(2, 3, 3)
+plt.imshow(Image_mapping_linear)
+plt.axis('off')
+plt.title('MappingTransport (linear)')
-pl.subplot(2, 3, 6)
-pl.imshow(Image_mapping_gaussian)
-pl.axis('off')
-pl.title('MappingTransport (gaussian)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(Image_mapping_gaussian)
+plt.axis('off')
+plt.title('MappingTransport (gaussian)')
+plt.tight_layout()
-pl.show()
+plt.show()