diff options
Diffstat (limited to 'examples/da/plot_otda_color_images.py')
-rw-r--r-- | examples/da/plot_otda_color_images.py | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/examples/da/plot_otda_color_images.py b/examples/da/plot_otda_color_images.py index 805d0b0..3984afb 100644 --- a/examples/da/plot_otda_color_images.py +++ b/examples/da/plot_otda_color_images.py @@ -22,7 +22,8 @@ from scipy import ndimage import matplotlib.pylab as pl import ot -np.random.seed(42) + +r = np.random.RandomState(42) def im2mat(I): @@ -39,6 +40,10 @@ def minmax(I): return np.clip(I, 0, 1) +############################################################################## +# generate data +############################################################################## + # Loading images I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 @@ -48,12 +53,17 @@ X2 = im2mat(I2) # training samples nb = 1000 -idx1 = np.random.randint(X1.shape[0], size=(nb,)) -idx2 = np.random.randint(X2.shape[0], size=(nb,)) +idx1 = r.randint(X1.shape[0], size=(nb,)) +idx2 = r.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] + +############################################################################## +# Instantiate the different transport algorithms and fit them +############################################################################## + # EMDTransport ot_emd = ot.da.EMDTransport() ot_emd.fit(Xs=Xs, Xt=Xt) @@ -75,6 +85,7 @@ I2t = minmax(mat2im(transp_Xt_emd, I2.shape)) I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape)) I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) + ############################################################################## # plot original image ############################################################################## @@ -91,6 +102,7 @@ pl.imshow(I2) pl.axis('off') pl.title('Image 2') + ############################################################################## # scatter plot of colors ############################################################################## @@ -112,6 +124,7 @@ pl.ylabel('Blue') pl.title('Image 2') pl.tight_layout() + ############################################################################## # plot new images ############################################################################## |