summaryrefslogtreecommitdiff
path: root/examples/domain-adaptation/plot_otda_linear_mapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/domain-adaptation/plot_otda_linear_mapping.py')
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py81
1 files changed, 44 insertions, 37 deletions
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index dbf16b8..a44096a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,9 +13,11 @@ Linear OT mapping estimation
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+from matplotlib import pyplot as plt
import ot
##############################################################################
@@ -26,17 +28,19 @@ n = 1000
d = 2
sigma = .1
+rng = np.random.RandomState(42)
+
# source samples
-angles = np.random.rand(n, 1) * 2 * np.pi
+angles = rng.rand(n, 1) * 2 * np.pi
xs = np.concatenate((np.sin(angles), np.cos(angles)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xs[:n // 2, 1] += 2
# target samples
-anglet = np.random.rand(n, 1) * 2 * np.pi
+anglet = rng.rand(n, 1) * 2 * np.pi
xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xt[:n // 2, 1] += 2
@@ -48,9 +52,9 @@ xt = xt.dot(A) + b
# Plot data
# ---------
-pl.figure(1, (5, 5))
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
+plt.figure(1, (5, 5))
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
##############################################################################
@@ -66,22 +70,22 @@ xst = xs.dot(Ae) + be
# Plot transported samples
# ------------------------
-pl.figure(1, (5, 5))
-pl.clf()
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
-pl.plot(xst[:, 0], xst[:, 1], '+')
+plt.figure(1, (5, 5))
+plt.clf()
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
+plt.plot(xst[:, 0], xst[:, 1], '+')
-pl.show()
+plt.show()
##############################################################################
# Load image data
# ---------------
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
def mat2im(X, shape):
@@ -89,13 +93,16 @@ def mat2im(X, shape):
return X.reshape(shape)
-def minmax(I):
- return np.clip(I, 0, 1)
+def minmax(img):
+ return np.clip(img, 0, 1)
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
@@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape))
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 7))
+plt.figure(2, figsize=(10, 7))
-pl.subplot(2, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 2, 3)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Mapping Im. 1')
+plt.subplot(2, 2, 3)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Mapping Im. 1')
-pl.subplot(2, 2, 4)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Inverse mapping Im. 2')
+plt.subplot(2, 2, 4)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Inverse mapping Im. 2')