summaryrefslogtreecommitdiff
path: root/examples/plot_OTDA_mapping.py
blob: 78b57e710a2c4de1676235e0a4babd6551f0273f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-
"""
===============================================
OT mapping estimation for domain adaptation [8]
===============================================

[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
    discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
"""

import numpy as np
import matplotlib.pylab as pl
import ot



#%% dataset generation

# Seed NumPy's global RNG so every run of the example draws the same data.
np.random.seed(0)

n = 100                 # number of samples drawn for each of the two domains
nz = 0.1                # noise parameter forwarded to get_data_classif
theta = 2 * np.pi / 20  # value forwarded as `theta` when drawing the target

# Both domains come from the 'gaussrot' generator; only the target call
# receives the `theta` argument.
xs, ys = ot.datasets.get_data_classif('gaussrot', n, nz=nz)
xt, yt = ot.datasets.get_data_classif('gaussrot', n, theta=theta, nz=nz)

# Scale the class-2 target mode by 3 so that its variance changes (hence no
# linear mapping can match source to target), then shift every target by 4.
xt[yt == 2] *= 3
xt = xt + 4


#%% plot samples

pl.figure(1, figsize=(8, 5))
pl.clf()

# Draw the source cloud ('+') first, then the target cloud ('o'); both are
# colored by their class labels.
for pts, lab, mk, txt in ((xs, ys, '+', 'Source samples'),
                          (xt, yt, 'o', 'Target samples')):
    pl.scatter(pts[:, 0], pts[:, 1], c=lab, marker=mk, label=txt)

pl.legend(loc=0)
pl.title('Source and target distributions')



#%% OT linear mapping estimation

eta = 1e-8   # quadratic regularization of the regression
mu = 1e0     # weight of the OT linear term
bias = True  # also estimate a bias term

ot_mapping = ot.da.OTDA_mapping_linear()
ot_mapping.fit(xs, xt, mu=mu, eta=eta, bias=bias, numItermax=20, verbose=True)

xst = ot_mapping.predict(xs)  # source samples sent through the estimated mapping
xst0 = ot_mapping.interp()    # barycentric mapping of the source samples

# Quick look at the linear results in figure 2 (useful when running this
# cell on its own); the final plotting cell clears and redraws figure 2.
pl.figure(2, figsize=(10, 7))
pl.clf()
for idx, (pts, caption) in enumerate(((xst0, "barycentric mapping"),
                                      (xst, "Learned mapping")), start=1):
    pl.subplot(2, 2, idx)
    pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples', alpha=.3)
    pl.scatter(pts[:, 0], pts[:, 1], c=ys, marker='+', label=caption)
    pl.title(caption)



#%% Kernel mapping estimation

eta = 1e-5   # quadratic regularization of the regression
mu = 1e-1    # weight of the OT linear term
bias = True  # also estimate a bias term
sigma = 1    # bandwidth of the Gaussian kernel

ot_mapping_kernel = ot.da.OTDA_mapping_kernel()
ot_mapping_kernel.fit(
    xs, xt,
    mu=mu, eta=eta, sigma=sigma, bias=bias,
    numItermax=10, verbose=True,
)

xst_kernel = ot_mapping_kernel.predict(xs)  # source samples sent through the kernel mapping
xst0_kernel = ot_mapping_kernel.interp()    # barycentric mapping of the source samples


#%% Plotting the mapped samples

# 2x2 grid: top row = linear mapping, bottom row = kernel mapping;
# left column = barycentric mapping, right column = estimated mapping.
pl.figure(2, (10, 7))
pl.clf()
pl.subplot(2, 2, 1)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples', alpha=.2)
pl.scatter(xst0[:, 0], xst0[:, 1], c=ys, marker='+', label='Mapped source samples')
pl.title("Bary. mapping (linear)")
pl.legend(loc=0)

pl.subplot(2, 2, 2)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples', alpha=.2)
pl.scatter(xst[:, 0], xst[:, 1], c=ys, marker='+', label='Learned mapping')
pl.title("Estim. mapping (linear)")

pl.subplot(2, 2, 3)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples', alpha=.2)
pl.scatter(xst0_kernel[:, 0], xst0_kernel[:, 1], c=ys, marker='+', label='barycentric mapping')
pl.title("Bary. mapping (kernel)")

pl.subplot(2, 2, 4)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples', alpha=.2)
pl.scatter(xst_kernel[:, 0], xst_kernel[:, 1], c=ys, marker='+', label='Learned mapping')
pl.title("Estim. mapping (kernel)")

# Display all open figures. Without this call, running the file as a plain
# script outside interactive mode exits before any figure is shown.
pl.show()