From 95b2a584d02da1a08e71f7ff3895d958e42ed2dc Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Tue, 11 Jul 2017 21:33:13 +0200 Subject: pep8 + pimp plot1D_mat rendering --- ot/plot.py | 116 ++++++++++++++++++++++++++++--------------------------------- 1 file changed, 53 insertions(+), 63 deletions(-) (limited to 'ot/plot.py') diff --git a/ot/plot.py b/ot/plot.py index 9737f8a..6f01731 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -4,89 +4,79 @@ Functions for plotting OT matrices import numpy as np -import matplotlib.pylab as pl +import matplotlib.pylab as plt from matplotlib import gridspec -def plot1D_mat(a,b,M,title=''): - """ Plot matrix M with the source and target 1D distribution - - Creates a subplot with the source distribution a on the left and +def plot1D_mat(a, b, M, title=''): + """ Plot matrix M with the source and target 1D distribution + + Creates a subplot with the source distribution a on the left and target distribution b on the tot. The matrix M is shown in between. - - + + Parameters ---------- - - a : np.array (na,) + a : np.array, shape (na,) Source distribution - b : np.array (nb,) - Target distribution - M : np.array (na,nb) + b : np.array, shape (nb,) + Target distribution + M : np.array, shape (na,nb) Matrix to plot - - - """ - - na=M.shape[0] - nb=M.shape[1] - + na, nb = M.shape + gs = gridspec.GridSpec(3, 3) - - - xa=np.arange(na) - xb=np.arange(nb) - - - ax1=pl.subplot(gs[0,1:]) - pl.plot(xb,b,'r',label='Target distribution') - pl.yticks(()) - pl.title(title) - - #pl.axis('off') - - ax2=pl.subplot(gs[1:,0]) - pl.plot(a,xa,'b',label='Source distribution') - pl.gca().invert_xaxis() - pl.gca().invert_yaxis() - pl.xticks(()) - #pl.ylim((0,n)) - #pl.axis('off') - - pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) - pl.imshow(M,interpolation='nearest') - - pl.xlim((0,nb)) - - -def plot2D_samples_mat(xs,xt,G,thr=1e-8,**kwargs): + + xa = np.arange(na) + xb = np.arange(nb) + + ax1 = plt.subplot(gs[0, 1:]) + plt.plot(xb, b, 'r', label='Target distribution') + plt.yticks(()) + plt.title(title) + + ax2 = plt.subplot(gs[1:, 0]) + plt.plot(a, xa, 'b', label='Source distribution') + plt.gca().invert_xaxis() + plt.gca().invert_yaxis() + plt.xticks(()) + + plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2) + plt.imshow(M, interpolation='nearest') + plt.axis('off') + + plt.xlim((0, nb)) + plt.tight_layout() + plt.subplots_adjust(wspace=0., hspace=0.2) + + +def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): """ Plot matrix M in 2D with lines using alpha values - - Plot lines between source and target 2D samples with a color + + Plot lines between source and target 2D samples with a color proportional to the value of the matrix G between samples. - - + + Parameters ---------- - - xs : np.array (ns,2) + xs : ndarray, shape (ns,2) Source samples positions - b : np.array (nt,2) + b : ndarray, shape (nt,2) Target samples positions - G : np.array (na,nb) + G : ndarray, shape (na,nb) OT matrix thr : float, optional threshold above which the line is drawn **kwargs : dict - paameters given to the plot functions (default color is black if nothing given) - + paameters given to the plot functions (default color is black if + nothing given) """ - if ('color' not in kwargs) and ('c' not in kwargs): - kwargs['color']='k' - mx=G.max() + if ('color' not in kwargs) and ('c' not in kwargs): + kwargs['color'] = 'k' + mx = G.max() for i in range(xs.shape[0]): for j in range(xt.shape[0]): - if G[i,j]/mx>thr: - pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs) - \ No newline at end of file + if G[i, j] / mx > thr: + plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], + alpha=G[i, j] / mx, **kwargs) -- cgit v1.2.3