summaryrefslogtreecommitdiff
path: root/examples/demo_OT_1D.py
blob: 29f20743963cc5a0cfaf5fc14172577c7ac3dac2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 21 09:51:45 2016

@author: rflamary
"""

import numpy as np
import matplotlib.pylab as pl
from matplotlib import gridspec
import ot



#%% parameters

n = 100  # nb bins

ma = 20  # mean of a
mb = 60  # mean of b

sa = 20  # std of a
sb = 60  # std of b

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions (unnormalized): exp(-(x - m)**2 / (2 * s**2)).
# Use ``**`` for the power: ``^`` is bitwise XOR in Python (20 ^ 2 == 22),
# which would silently make both Gaussians far narrower than sa / sb.
a = np.exp(-(x - ma) ** 2 / (2 * sa ** 2))
b = np.exp(-(x - mb) ** 2 / (2 * sb ** 2))

# normalization: each histogram sums to 1
a /= a.sum()
b /= b.sum()

# loss matrix: ot.dist (default metric) between the bin positions viewed as
# (n, 1) column vectors, rescaled in place so its largest entry is 1
M = ot.dist(x.reshape(-1, 1), x.reshape(-1, 1))
M /= M.max()

#%% plot the distributions

# Figure 1: source histogram (blue) and target histogram (red) over the bins
pl.figure(1)
pl.plot(x, a, 'b', label='Source distribution')
pl.plot(x, b, 'r', label='Target distribution')
pl.legend()

#%% plot distributions and loss matrix

# Figure 2: 3x3 grid with the target on top, the source on the left and the
# loss matrix M in the remaining 2x2 area.
pl.figure(2)
gs = gridspec.GridSpec(3, 3)

# top strip: target distribution, aligned with the columns of M
ax1 = pl.subplot(gs[0, 1:])
ax1.plot(x, b, 'r', label='Target distribution')
ax1.set_yticks(())

# left strip: source distribution drawn sideways, aligned with the rows of M
ax2 = pl.subplot(gs[1:, 0])
ax2.plot(a, x, 'b', label='Source distribution')
ax2.invert_xaxis()
ax2.invert_yaxis()
ax2.set_xticks(())

# main panel: the loss matrix, sharing its x axis with the top strip and its
# y axis with the left strip
pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
pl.imshow(M, interpolation='nearest')
pl.xlim((0, n))

#%% EMD

# EMD optimal transport matrix between histograms a and b for the cost M
G0 = ot.emd(a, b, M)

#%% plot EMD optimal transport matrix
# Figure 3: same layout as figure 2, with the transport matrix G0 in the
# main panel.
pl.figure(3)
gs = gridspec.GridSpec(3, 3)

# top strip: target distribution, aligned with the columns of G0
ax1 = pl.subplot(gs[0, 1:])
ax1.plot(x, b, 'r', label='Target distribution')
ax1.set_yticks(())

# left strip: source distribution drawn sideways, aligned with the rows of G0
ax2 = pl.subplot(gs[1:, 0])
ax2.plot(a, x, 'b', label='Source distribution')
ax2.invert_xaxis()
ax2.invert_yaxis()
ax2.set_xticks(())

# main panel: the transport matrix, axes shared with the two strips
pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
pl.imshow(G0, interpolation='nearest')
pl.xlim((0, n))