summaryrefslogtreecommitdiff
path: root/ot/plot.py
blob: f78daf609e928491313972776da7777d2de7589a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import matplotlib.pylab as pl
from matplotlib import gridspec


def otplot1D(a, b, M, title=''):
    """Plot matrix M together with the source and target 1D distributions.

    The figure uses a 3x3 grid:

    - ``b`` is drawn in the top strip, sharing the matrix's x axis.
    - ``a`` is drawn sideways in the left strip, sharing the matrix's y axis.
      Its x axis is inverted so the curve grows away from the matrix.
    - ``M`` is shown with ``imshow`` in the lower-right 2x2 cells.

    Parameters
    ----------
    a : array-like, length M.shape[0]
        Source distribution, one value per row of M.
    b : array-like, length M.shape[1]
        Target distribution, one value per column of M.
    M : array, shape (na, nb)
        Matrix to display.
    title : str, optional
        Title placed above the target-distribution strip.
    """
    na, nb = M.shape[0], M.shape[1]

    gs = gridspec.GridSpec(3, 3)

    xa = np.arange(na)
    xb = np.arange(nb)

    # Top strip: target distribution indexed by the columns of M.
    ax1 = pl.subplot(gs[0, 1:])
    pl.plot(xb, b, 'r', label='Target distribution')
    pl.yticks(())
    pl.title(title)

    # Left strip: source distribution indexed by the rows of M, drawn sideways.
    ax2 = pl.subplot(gs[1:, 0])
    pl.plot(a, xa, 'b', label='Source distribution')
    pl.gca().invert_xaxis()
    # Match imshow's default top-down row order (origin='upper').
    pl.gca().invert_yaxis()
    pl.xticks(())

    pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
    pl.imshow(M, interpolation='nearest')

    # imshow centres pixel k at x=k, so the image spans [-0.5, nb - 0.5].
    # A (0, nb) limit would clip half of the first column and leave an
    # empty half-column on the right.
    pl.xlim((-0.5, nb - 0.5))


def otplot2D_samples(xs, xt, G, thr=1e-8, **kwargs):
    """Plot the couplings of matrix G between 2D samples as segments.

    A segment is drawn from ``xs[i]`` to ``xt[j]`` for every entry whose
    normalised weight ``G[i, j] / G.max()`` is strictly greater than ``thr``.
    That normalised weight is used as the segment's alpha (opacity).

    Parameters
    ----------
    xs : array, shape (ns, >=2)
        Source samples. Only the first two columns are used.
    xt : array, shape (nt, >=2)
        Target samples. Only the first two columns are used.
    G : array, shape (ns, nt)
        Coupling weights between ``xs`` and ``xt``.
    thr : float, optional
        Relative threshold. Entries with ``G[i, j] / G.max() <= thr`` are
        not drawn.
    **kwargs
        Forwarded to ``pl.plot``. The line colour defaults to black
        ('k') unless ``color`` or ``c`` is supplied.

    Nothing is drawn when ``G`` is empty or its maximum is not positive.
    """
    if 'color' not in kwargs and 'c' not in kwargs:
        kwargs['color'] = 'k'
    G = np.asarray(G)
    if G.size == 0:
        return
    mx = G.max()
    # An all-zero (or NaN) G has nothing to draw, and normalising it would
    # divide by zero.
    if not mx > 0:
        return
    Gn = G / mx
    # np.nonzero yields indices in row-major order, i.e. the same drawing
    # order as iterating i over sources then j over targets.
    for i, j in zip(*np.nonzero(Gn > thr)):
        pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
                alpha=Gn[i, j], **kwargs)