summaryrefslogtreecommitdiff
path: root/examples/plot_OT_L1_vs_L2.py
blob: 9bb92fef36943ef35cd808a96cf3d528a518a7ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# -*- coding: utf-8 -*-
"""
==========================================
2D Optimal transport for different metrics
==========================================

The figure idea is borrowed from Figs. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf


@author: rflamary
"""

import numpy as np
import matplotlib.pylab as pl
import ot

#%% parameters and data generation

# Two toy problems are solved in turn:
#   data == 0: source and target samples on the same arc of the unit circle,
#              the target being the source shifted by one point;
#   data == 1: source samples along the x-axis, target samples along the y-axis.
for data in range(2):

    if data:
        n = 20  # nb samples
        # source samples on the (slightly tilted) x-axis
        xs = np.zeros((n, 2))
        xs[:, 0] = np.arange(n) + 1
        xs[:, 1] = (np.arange(n) + 1) * -0.001  # to make it strictly convex...

        # target samples on the y-axis
        xt = np.zeros((n, 2))
        xt[:, 1] = np.arange(n) + 1
    else:
        n = 50  # nb samples
        # n + 1 points on an arc of the unit circle (less than 0.9 of a turn)
        angles = (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi
        xtot = np.zeros((n + 1, 2))
        xtot[:, 0] = np.cos(angles)
        xtot[:, 1] = np.sin(angles)

        # source = first n points, target = last n points (shift by one)
        xs = xtot[:n, :]
        xt = xtot[1:, :]

    a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples

    # Euclidean cost, normalized so its largest entry is 1
    M1 = ot.dist(xs, xt, metric='euclidean')
    M1 /= M1.max()

    # squared Euclidean cost, normalized so its largest entry is 1
    M2 = ot.dist(xs, xt, metric='sqeuclidean')
    M2 /= M2.max()

    # square root of the Euclidean cost, normalized so its largest entry is 1
    Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
    Mp /= Mp.max()

    #%% plot samples

    pl.figure(1 + 3 * data)
    pl.clf()
    pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
    pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
    pl.axis('equal')
    pl.title('Source and target distributions')

    # one panel per cost matrix
    costs = [
        (M1, 'Euclidean cost'),
        (M2, 'Squared Euclidean cost'),
        (Mp, 'Sqrt Euclidean cost'),
    ]
    pl.figure(2 + 3 * data, figsize=(15, 5))
    for i, (M, title) in enumerate(costs):
        pl.subplot(1, 3, i + 1)
        pl.imshow(M, interpolation='nearest')
        pl.title(title)

    #%% EMD

    # exact OT plans for each ground cost
    G1 = ot.emd(a, b, M1)
    G2 = ot.emd(a, b, M2)
    Gp = ot.emd(a, b, Mp)

    # one panel per transport plan, drawn over the samples
    plans = [
        (G1, 'OT Euclidean'),
        (G2, 'OT squared Euclidean'),
        (Gp, 'OT sqrt Euclidean'),
    ]
    pl.figure(3 + 3 * data, figsize=(15, 5))
    for i, (G, title) in enumerate(plans):
        pl.subplot(1, 3, i + 1)
        ot.plot.plot2D_samples_mat(xs, xt, G, c=[.5, .5, 1])
        pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
        pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
        pl.axis('equal')
        pl.title(title)