summaryrefslogtreecommitdiff
path: root/examples/da/plot_otda_color_images.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/da/plot_otda_color_images.py')
-rw-r--r--examples/da/plot_otda_color_images.py19
1 files changed, 16 insertions, 3 deletions
diff --git a/examples/da/plot_otda_color_images.py b/examples/da/plot_otda_color_images.py
index 805d0b0..3984afb 100644
--- a/examples/da/plot_otda_color_images.py
+++ b/examples/da/plot_otda_color_images.py
@@ -22,7 +22,8 @@ from scipy import ndimage
import matplotlib.pylab as pl
import ot
-np.random.seed(42)
+
+r = np.random.RandomState(42)
def im2mat(I):
@@ -39,6 +40,10 @@ def minmax(I):
return np.clip(I, 0, 1)
+##############################################################################
+# generate data
+##############################################################################
+
# Loading images
I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
@@ -48,12 +53,17 @@ X2 = im2mat(I2)
# training samples
nb = 1000
-idx1 = np.random.randint(X1.shape[0], size=(nb,))
-idx2 = np.random.randint(X2.shape[0], size=(nb,))
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+##############################################################################
+
# EMDTransport
ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)
@@ -75,6 +85,7 @@ I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
+
##############################################################################
# plot original image
##############################################################################
@@ -91,6 +102,7 @@ pl.imshow(I2)
pl.axis('off')
pl.title('Image 2')
+
##############################################################################
# scatter plot of colors
##############################################################################
@@ -112,6 +124,7 @@ pl.ylabel('Blue')
pl.title('Image 2')
pl.tight_layout()
+
##############################################################################
# plot new images
##############################################################################