1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
# -*- coding: utf-8 -*-
"""
==========================================
2D Optimal transport for different metrics
==========================================
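Compares the optimal transport plans obtained with the Euclidean, squared
Euclidean and square-root Euclidean ground costs on two toy 2D datasets.
The figure idea is borrowed from Figs. 1 and 2 of
https://arxiv.org/pdf/1706.07650.pdf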
@author: rflamary
"""
import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot  # plotting helpers live in a submodule that ``import ot`` may not load
#%% parameters and data generation

# Run the same experiment on two toy datasets; ``data`` also offsets the
# figure numbers (1-3 for data == 0, 4-6 for data == 1):
#   data == 0: n = 50 points on an arc of the unit circle; the targets are
#              the sources shifted one step further along the arc.
#   data == 1: n = 20 points close to the x-axis vs n points on the y-axis.
for data in range(2):

    if data:
        n = 20  # nb samples
        xs = np.zeros((n, 2))
        xs[:, 0] = np.arange(n) + 1
        # Tiny tilt of the source points away from the x-axis. The original
        # note said "to make it strictly convex"; presumably it breaks ties
        # between equally optimal plans -- TODO confirm.
        xs[:, 1] = (np.arange(n) + 1) * -0.001

        xt = np.zeros((n, 2))
        xt[:, 1] = np.arange(n) + 1
    else:
        n = 50  # nb samples
        # n + 1 equally spaced angles, (k + 1) * 0.9 / (n + 2) * 2 * pi for
        # k = 0..n, i.e. points on less than 90% of the unit circle.
        angles = (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi
        xtot = np.zeros((n + 1, 2))
        xtot[:, 0] = np.cos(angles)
        xtot[:, 1] = np.sin(angles)

        # Source = first n points, target = last n points.
        xs = xtot[:n, :]
        xt = xtot[1:, :]

    a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples

    # Ground cost matrices, each rescaled so that its largest entry is 1.
    M1 = ot.dist(xs, xt, metric='euclidean')  # Euclidean
    M1 /= M1.max()
    M2 = ot.dist(xs, xt, metric='sqeuclidean')  # squared Euclidean
    M2 /= M2.max()
    Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))  # sqrt of Euclidean
    Mp /= Mp.max()

    #%% plot samples

    pl.figure(1 + 3 * data)
    pl.clf()
    pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
    pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
    pl.axis('equal')
    pl.title('Source and target distributions')

    #%% plot cost matrices

    cost_panels = [
        (M1, 'Euclidean cost'),
        (M2, 'Squared Euclidean cost'),
        (Mp, 'Sqrt Euclidean cost'),
    ]
    pl.figure(2 + 3 * data, figsize=(15, 5))
    pl.clf()  # like figure 1: don't draw over stale axes when re-run
    for k, (M, title) in enumerate(cost_panels, start=1):
        pl.subplot(1, 3, k)
        pl.imshow(M, interpolation='nearest')
        pl.title(title)

    #%% EMD

    # OT plans computed by ot.emd for each of the three ground costs.
    G1 = ot.emd(a, b, M1)
    G2 = ot.emd(a, b, M2)
    Gp = ot.emd(a, b, Mp)

    plan_panels = [
        (G1, 'OT Euclidean'),
        (G2, 'OT squared Euclidean'),
        (Gp, 'OT sqrt Euclidean'),
    ]
    pl.figure(3 + 3 * data, figsize=(15, 5))
    pl.clf()  # like figure 1: don't draw over stale axes when re-run
    for k, (G, title) in enumerate(plan_panels, start=1):
        pl.subplot(1, 3, k)
        # Overlay the coupling G (light blue) between source and target.
        ot.plot.plot2D_samples_mat(xs, xt, G, c=[.5, .5, 1])
        pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
        pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
        pl.axis('equal')
        pl.title(title)
|