summaryrefslogtreecommitdiff
path: root/ot/plot.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-11 21:33:13 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:05:12 +0200
commit95b2a584d02da1a08e71f7ff3895d958e42ed2dc (patch)
tree45d22dd460a25e422ec975fc4fe48318222d0da6 /ot/plot.py
parentd0258f103946eb78f2fff6a3a82d85744ba27ec8 (diff)
pep8 + pimp plot1D_mat rendering
Diffstat (limited to 'ot/plot.py')
-rw-r--r--ot/plot.py116
1 files changed, 53 insertions, 63 deletions
diff --git a/ot/plot.py b/ot/plot.py
index 9737f8a..6f01731 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -4,89 +4,79 @@ Functions for plotting OT matrices
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pylab as plt
from matplotlib import gridspec
-def plot1D_mat(a,b,M,title=''):
- """ Plot matrix M with the source and target 1D distribution
-
- Creates a subplot with the source distribution a on the left and
+def plot1D_mat(a, b, M, title=''):
+ """ Plot matrix M with the source and target 1D distribution
+
+ Creates a subplot with the source distribution a on the left and
target distribution b on the tot. The matrix M is shown in between.
-
-
+
+
Parameters
----------
-
- a : np.array (na,)
+ a : np.array, shape (na,)
Source distribution
- b : np.array (nb,)
- Target distribution
- M : np.array (na,nb)
+ b : np.array, shape (nb,)
+ Target distribution
+ M : np.array, shape (na,nb)
Matrix to plot
-
-
-
"""
-
- na=M.shape[0]
- nb=M.shape[1]
-
+ na, nb = M.shape
+
gs = gridspec.GridSpec(3, 3)
-
-
- xa=np.arange(na)
- xb=np.arange(nb)
-
-
- ax1=pl.subplot(gs[0,1:])
- pl.plot(xb,b,'r',label='Target distribution')
- pl.yticks(())
- pl.title(title)
-
- #pl.axis('off')
-
- ax2=pl.subplot(gs[1:,0])
- pl.plot(a,xa,'b',label='Source distribution')
- pl.gca().invert_xaxis()
- pl.gca().invert_yaxis()
- pl.xticks(())
- #pl.ylim((0,n))
- #pl.axis('off')
-
- pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
- pl.imshow(M,interpolation='nearest')
-
- pl.xlim((0,nb))
-
-
-def plot2D_samples_mat(xs,xt,G,thr=1e-8,**kwargs):
+
+ xa = np.arange(na)
+ xb = np.arange(nb)
+
+ ax1 = plt.subplot(gs[0, 1:])
+ plt.plot(xb, b, 'r', label='Target distribution')
+ plt.yticks(())
+ plt.title(title)
+
+ ax2 = plt.subplot(gs[1:, 0])
+ plt.plot(a, xa, 'b', label='Source distribution')
+ plt.gca().invert_xaxis()
+ plt.gca().invert_yaxis()
+ plt.xticks(())
+
+ plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
+ plt.imshow(M, interpolation='nearest')
+ plt.axis('off')
+
+ plt.xlim((0, nb))
+ plt.tight_layout()
+ plt.subplots_adjust(wspace=0., hspace=0.2)
+
+
+def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
""" Plot matrix M in 2D with lines using alpha values
-
- Plot lines between source and target 2D samples with a color
+
+ Plot lines between source and target 2D samples with a color
proportional to the value of the matrix G between samples.
-
-
+
+
Parameters
----------
-
- xs : np.array (ns,2)
+ xs : ndarray, shape (ns,2)
Source samples positions
- b : np.array (nt,2)
+ b : ndarray, shape (nt,2)
Target samples positions
- G : np.array (na,nb)
+ G : ndarray, shape (na,nb)
OT matrix
thr : float, optional
threshold above which the line is drawn
**kwargs : dict
- paameters given to the plot functions (default color is black if nothing given)
-
+ paameters given to the plot functions (default color is black if
+ nothing given)
"""
- if ('color' not in kwargs) and ('c' not in kwargs):
- kwargs['color']='k'
- mx=G.max()
+ if ('color' not in kwargs) and ('c' not in kwargs):
+ kwargs['color'] = 'k'
+ mx = G.max()
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
- if G[i,j]/mx>thr:
- pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs)
- \ No newline at end of file
+ if G[i, j] / mx > thr:
+ plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
+ alpha=G[i, j] / mx, **kwargs)