summaryrefslogtreecommitdiff
path: root/examples/demo_OTDA_color_images.py
blob: 7a08ea08fa2b69b0dfca464da1108da48100b4a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-
"""
Demo of optimal transport (OT) for domain adaptation, applied to color adaptation between images as in [6].

[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
"""

import numpy as np
import scipy.ndimage as spi
import matplotlib.pylab as pl
import ot


#%% Loading images

# Load both images as float arrays. scipy.ndimage.imread was deprecated in
# SciPy 1.0 and removed in SciPy 1.2, so matplotlib's reader is used instead
# (it delegates JPEG decoding to Pillow).
# The pixels are assumed to be 8-bit (0-255) — dividing by 256 then maps them
# into [0, 1), which imshow accepts for float RGB data.
I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256

#%% Plot images

# Show the two input images side by side before any adaptation.
pl.figure(1)

pl.subplot(1, 2, 1)
pl.imshow(I1)
pl.title('Image 1')

pl.subplot(1, 2, 2)
pl.imshow(I2)
pl.title('Image 2')

pl.show()

#%% Image conversion and dataset generation

def im2mat(I):
    """Flatten an image into a matrix holding one pixel per row.

    The first two axes of ``I`` (rows and columns) are merged, and the
    third axis (color channels) is kept.

    Returns an array of shape ``(I.shape[0] * I.shape[1], I.shape[2])``.
    """
    height, width, channels = I.shape[0], I.shape[1], I.shape[2]
    n_pixels = height * width
    return I.reshape((n_pixels, channels))

def mat2im(X, shape):
    """Reshape a pixel matrix (one pixel per row) back into an image.

    ``shape`` is the target image shape, typically the ``.shape`` of the
    image the matrix was originally produced from.
    """
    image = X.reshape(shape)
    return image

# One row per pixel, one column per color channel.
X1 = im2mat(I1)
X2 = im2mat(I2)

# Training samples: a random subset of pixels from each image is used to fit
# the transport, because fitting on every pixel would be far too expensive.
# Indices are drawn with replacement, so a pixel can be picked more than once.
# A dedicated, seeded generator makes the demo reproducible from run to run.
nb = 1000
rng = np.random.RandomState(42)
idx1 = rng.randint(X1.shape[0], size=(nb,))
idx2 = rng.randint(X2.shape[0], size=(nb,))

xs = X1[idx1, :]
xt = X2[idx2, :]

#%% domain adaptation between images

# Strength of the entropic regularization used by the Sinkhorn-based model.
lambd = 1e-1

# Exact OT: the transport plan is obtained by solving the linear program.
da_emd = ot.da.OTDA()
da_emd.fit(xs, xt)

# Regularized OT: entropy-regularized transport solved with Sinkhorn.
da_entrop = ot.da.OTDA_sinkhorn()
da_entrop.fit(xs, xt, reg=lambd)



#%% prediction between images (using out of sample prediction as in [6])

# Map every pixel of each full image with both fitted models.
# NOTE(review): the second argument -1 presumably requests the reverse
# (target -> source) mapping — confirm against ot.da.OTDA.predict.
X1t = da_emd.predict(X1)
X1te = da_entrop.predict(X1)

X2t = da_emd.predict(X2, -1)
X2te = da_entrop.predict(X2, -1)


def minmax(I):
    """Clamp every value of ``I`` to the closed interval [0, 1].

    Adapted pixel colors can overshoot the valid range, and this brings them
    back so the result can be displayed as a float image.

    Returns an array with the same shape as ``I``.
    """
    # np.clip is the library equivalent of np.minimum(np.maximum(I, 0), 1).
    return np.clip(I, 0, 1)

# Fold each adapted pixel matrix back into the shape of the image it came
# from, then clamp to [0, 1] so the result can be shown with imshow.
I1t = minmax(mat2im(X1t, I1.shape))
I1te = minmax(mat2im(X1te, I1.shape))

I2t = minmax(mat2im(X2t, I2.shape))
I2te = minmax(mat2im(X2te, I2.shape))

#%% plot all images

# 2x3 grid: the top row holds image 1 and its adaptations, the bottom row
# holds image 2 and its adaptations.
pl.figure(2, (10, 8))

panels = [
    (I1, 'Image 1'),
    (I1t, 'Image 1 Adapt'),
    (I1te, 'Image 1 Adapt (reg)'),
    (I2, 'Image 2'),
    (I2t, 'Image 2 Adapt'),
    (I2te, 'Image 2 Adapt (reg)'),
]

for position, (image, caption) in enumerate(panels, start=1):
    pl.subplot(2, 3, position)
    pl.imshow(image)
    pl.title(caption)

pl.show()