diff options
author | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2016-12-02 12:49:38 +0100 |
---|---|---|
committer | Alexandre Gramfort <alexandre.gramfort@m4x.org> | 2016-12-02 12:57:25 +0100 |
commit | f439f777084690ecbf54bcd8d67dadc883fffa31 (patch) | |
tree | 56c8a160fb6edcdaca8ca6ce6de1949b9bc33b77 /examples/plot_OTDA_mapping_color_images.py | |
parent | 8dbfd3edae649f5f3e87be4a3ce446c59729b2f7 (diff) |
first attempt to support sphinx-gallery
Diffstat (limited to 'examples/plot_OTDA_mapping_color_images.py')
-rw-r--r-- | examples/plot_OTDA_mapping_color_images.py | 158 |
1 files changed, 158 insertions, 0 deletions
diff --git a/examples/plot_OTDA_mapping_color_images.py b/examples/plot_OTDA_mapping_color_images.py new file mode 100644 index 0000000..0cd6c9c --- /dev/null +++ b/examples/plot_OTDA_mapping_color_images.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +""" +====================================================================================================================== +Demo of Optimal transport for domain adaptation with image color adaptation as in [6] with mapping estimation from [8] +====================================================================================================================== + +[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized + discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. + +""" + +import numpy as np +import scipy.ndimage as spi +import matplotlib.pylab as pl +import ot + + +#%% Loading images + +I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256 +I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256 + +#%% Plot images + +pl.figure(1) + +pl.subplot(1,2,1) +pl.imshow(I1) +pl.title('Image 1') + +pl.subplot(1,2,2) +pl.imshow(I2) +pl.title('Image 2') + +pl.show() + +#%% Image conversion and dataset generation + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0]*I.shape[1],I.shape[2])) + +def mat2im(X,shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + +X1=im2mat(I1) +X2=im2mat(I2) + +# training samples +nb=1000 +idx1=np.random.randint(X1.shape[0],size=(nb,)) +idx2=np.random.randint(X2.shape[0],size=(nb,)) + +xs=X1[idx1,:] +xt=X2[idx2,:] + +#%% Plot image distributions + + +pl.figure(2,(10,5)) + +pl.subplot(1,2,1) +pl.scatter(xs[:,0],xs[:,2],c=xs) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 1') + +pl.subplot(1,2,2) +#pl.imshow(I2) +pl.scatter(xt[:,0],xt[:,2],c=xt) +pl.axis([0,1,0,1]) +pl.xlabel('Red') +pl.ylabel('Blue') +pl.title('Image 2') + +pl.show() + + + +#%% domain adaptation between images +def minmax(I): + return np.minimum(np.maximum(I,0),1) +# LP problem +da_emd=ot.da.OTDA() # init class +da_emd.fit(xs,xt) # fit distributions + +X1t=da_emd.predict(X1) # out of sample +I1t=minmax(mat2im(X1t,I1.shape)) + +# sinkhorn regularization +lambd=1e-1 +da_entrop=ot.da.OTDA_sinkhorn() +da_entrop.fit(xs,xt,reg=lambd) + +X1te=da_entrop.predict(X1) +I1te=minmax(mat2im(X1te,I1.shape)) + +# linear mapping estimation +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +X1tl=ot_mapping.predict(X1) # use the estimated mapping +I1tl=minmax(mat2im(X1tl,I1.shape)) + +# nonlinear mapping estimation +eta=1e-2 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=False # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping +I1tn=minmax(mat2im(X1tn,I1.shape)) +#%% plot images + + +pl.figure(2,(10,8)) + +pl.subplot(2,3,1) + +pl.imshow(I1) +pl.title('Im. 1') + +pl.subplot(2,3,2) + +pl.imshow(I2) +pl.title('Im. 2') + + +pl.subplot(2,3,3) +pl.imshow(I1t) +pl.title('Im. 1 Interp LP') + +pl.subplot(2,3,4) +pl.imshow(I1te) +pl.title('Im. 1 Interp Entrop') + + +pl.subplot(2,3,5) +pl.imshow(I1tl) +pl.title('Im. 1 Linear mapping') + +pl.subplot(2,3,6) +pl.imshow(I1tn) +pl.title('Im. 1 nonlinear mapping') + +pl.show()
\ No newline at end of file |