summaryrefslogtreecommitdiff
path: root/examples/plot_otda_semi_supervised.py
blob: 8a67720b9b4a25109547c15a4e22454a56f7561d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# -*- coding: utf-8 -*-
"""
============================================
OTDA unsupervised vs semi-supervised setting
============================================

This example introduces semi-supervised domain adaptation in a 2D setting.
It states the problem of semi-supervised domain adaptation and introduces
some optimal transport approaches to solve it.

Quantities such as optimal couplings, greater coupling coefficients and
transported samples are represented in order to give a visual understanding
of what the transport methods are doing.
"""

# Authors: Remi Flamary <remi.flamary@unice.fr>
#          Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

import matplotlib.pylab as pl
import ot


##############################################################################
# Generate data
# -------------

# Number of points drawn for each domain.
n_samples_source, n_samples_target = 150, 150

# Draw the source set first, then the target set; each call returns the
# samples together with their class labels.
make_classif = ot.datasets.make_data_classif
Xs, ys = make_classif('3gauss', n_samples_source)
Xt, yt = make_classif('3gauss2', n_samples_target)


##############################################################################
# Transport source samples onto target samples
# --------------------------------------------


# Unsupervised domain adaptation: the estimator is fitted with the source
# and target samples only.
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1).fit(Xs=Xs, Xt=Xt)
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)

# Semi-supervised domain adaptation: same estimator and regularization, but
# the source and target labels are passed to fit as well.
ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1).fit(
    Xs=Xs, Xt=Xt, ys=ys, yt=yt)
transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)

# semi-supervised DA uses the available labeled target samples to modify the
# cost matrix involved in the OT problem. The cost of transporting a source
# sample of class A onto a target sample of class B != A is set to infinity,
# or to a very large value

# note that in the present case we consider that all the target samples are
# labeled. In practical applications, some target samples might not have
# labels; in this case the elements of yt corresponding to these samples
# should be set to -1.

# Warning: we recall that -1 cannot be used as a class label


##############################################################################
# Fig 1 : plots source and target samples + matrix of pairwise distance
# ---------------------------------------------------------------------

pl.figure(1, figsize=(10, 10))

# Top row: samples of each domain, colored by their class label.
sample_panels = (
    (Xs, ys, '+', 'Source samples', 'Source  samples'),
    (Xt, yt, 'o', 'Target samples', 'Target samples'),
)
for panel, spec in enumerate(sample_panels, start=1):
    points, labels, symbol, legend_text, panel_title = spec
    pl.subplot(2, 2, panel)
    pl.scatter(points[:, 0], points[:, 1], c=labels, marker=symbol,
               label=legend_text)
    pl.xticks([])
    pl.yticks([])
    pl.legend(loc=0)
    pl.title(panel_title)

# Bottom row: cost matrix stored by each fitted transport.
cost_panels = (
    (ot_sinkhorn_un, 'Cost matrix - unsupervised DA'),
    (ot_sinkhorn_semi, 'Cost matrix - semisupervised DA'),
)
for panel, (transport, panel_title) in enumerate(cost_panels, start=3):
    pl.subplot(2, 2, panel)
    pl.imshow(transport.cost_, interpolation='nearest')
    pl.xticks([])
    pl.yticks([])
    pl.title(panel_title)

pl.tight_layout()

# the optimal coupling in the semi-supervised DA case will exhibit a shape
# similar to the cost matrix (block diagonal matrix)


##############################################################################
# Fig 2 : plots optimal couplings for the different methods
# ---------------------------------------------------------

pl.figure(2, figsize=(8, 4))

# One panel per fitted transport, left to right, showing its coupling matrix.
coupling_panels = (
    (ot_sinkhorn_un, 'Optimal coupling\nUnsupervised DA'),
    (ot_sinkhorn_semi, 'Optimal coupling\nSemi-supervised DA'),
)
for panel, (transport, panel_title) in enumerate(coupling_panels, start=1):
    pl.subplot(1, 2, panel)
    pl.imshow(transport.coupling_, interpolation='nearest')
    pl.xticks([])
    pl.yticks([])
    pl.title(panel_title)

pl.tight_layout()


##############################################################################
# Fig 3 : plot transported samples
# --------------------------------

# Display the transported source samples on top of the target samples.
# Both panels come from SinkhornTransport: the left one was fitted without
# labels (unsupervised), the right one with ys and yt (semi-supervised), so
# the titles name the setting rather than the solver.
pl.figure(3, figsize=(8, 4))  # figure number matches the "Fig 3" section
pl.subplot(1, 2, 1)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
           label='Target samples', alpha=0.5)
pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
           marker='+', label='Transp samples', s=30)
pl.title('Transported samples\nUnsupervised DA')
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
           label='Target samples', alpha=0.5)
pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
           marker='+', label='Transp samples', s=30)
pl.title('Transported samples\nSemi-supervised DA')
pl.xticks([])
pl.yticks([])

pl.tight_layout()
pl.show()