From 95b2a584d02da1a08e71f7ff3895d958e42ed2dc Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Tue, 11 Jul 2017 21:33:13 +0200 Subject: pep8 + pimp plot1D_mat rendering --- examples/plot_OT_1D.py | 43 +++++++++--------- ot/plot.py | 116 ++++++++++++++++++++++--------------------------- 2 files changed, 75 insertions(+), 84 deletions(-) diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 6661aa3..b36fa6a 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -8,49 +8,50 @@ """ import numpy as np -import matplotlib.pylab as pl +import matplotlib.pylab as plt import ot from ot.datasets import get_1D_gauss as gauss - #%% parameters -n=100 # nb bins +n = 100 # nb bins # bin positions -x=np.arange(n,dtype=np.float64) +x = np.arange(n, dtype=np.float64) # Gaussian distributions -a=gauss(n,m=20,s=5) # m= mean, s= std -b=gauss(n,m=60,s=10) +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) # loss matrix -M=ot.dist(x.reshape((n,1)),x.reshape((n,1))) -M/=M.max() +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() #%% plot the distributions -pl.figure(1) -pl.plot(x,a,'b',label='Source distribution') -pl.plot(x,b,'r',label='Target distribution') -pl.legend() +plt.figure(1) +plt.plot(x, a, 'b', label='Source distribution') +plt.plot(x, b, 'r', label='Target distribution') +plt.legend() #%% plot distributions and loss matrix -pl.figure(2) -ot.plot.plot1D_mat(a,b,M,'Cost matrix M') +plt.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% EMD -G0=ot.emd(a,b,M) +G0 = ot.emd(a, b, M) -pl.figure(3) -ot.plot.plot1D_mat(a,b,G0,'OT matrix G0') +plt.figure(3, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') #%% Sinkhorn -lambd=1e-3 -Gs=ot.sinkhorn(a,b,M,lambd,verbose=True) +lambd = 1e-3 +Gs = ot.sinkhorn(a, b, M, lambd, verbose=True) + +plt.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn') -pl.figure(4) -ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn') +plt.show() diff --git a/ot/plot.py b/ot/plot.py index 9737f8a..6f01731 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -4,89 +4,79 @@ Functions for plotting OT matrices import numpy as np -import matplotlib.pylab as pl +import matplotlib.pylab as plt from matplotlib import gridspec -def plot1D_mat(a,b,M,title=''): - """ Plot matrix M with the source and target 1D distribution - - Creates a subplot with the source distribution a on the left and +def plot1D_mat(a, b, M, title=''): + """ Plot matrix M with the source and target 1D distribution + + Creates a subplot with the source distribution a on the left and target distribution b on the tot. The matrix M is shown in between. - - + + Parameters ---------- - - a : np.array (na,) + a : np.array, shape (na,) Source distribution - b : np.array (nb,) - Target distribution - M : np.array (na,nb) + b : np.array, shape (nb,) + Target distribution + M : np.array, shape (na,nb) Matrix to plot - - - """ - - na=M.shape[0] - nb=M.shape[1] - + na, nb = M.shape + gs = gridspec.GridSpec(3, 3) - - - xa=np.arange(na) - xb=np.arange(nb) - - - ax1=pl.subplot(gs[0,1:]) - pl.plot(xb,b,'r',label='Target distribution') - pl.yticks(()) - pl.title(title) - - #pl.axis('off') - - ax2=pl.subplot(gs[1:,0]) - pl.plot(a,xa,'b',label='Source distribution') - pl.gca().invert_xaxis() - pl.gca().invert_yaxis() - pl.xticks(()) - #pl.ylim((0,n)) - #pl.axis('off') - - pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2) - pl.imshow(M,interpolation='nearest') - - pl.xlim((0,nb)) - - -def plot2D_samples_mat(xs,xt,G,thr=1e-8,**kwargs): + + xa = np.arange(na) + xb = np.arange(nb) + + ax1 = plt.subplot(gs[0, 1:]) + plt.plot(xb, b, 'r', label='Target distribution') + plt.yticks(()) + plt.title(title) + + ax2 = plt.subplot(gs[1:, 0]) + plt.plot(a, xa, 'b', label='Source distribution') + plt.gca().invert_xaxis() + plt.gca().invert_yaxis() + plt.xticks(()) + + plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2) + plt.imshow(M, interpolation='nearest') + plt.axis('off') + + plt.xlim((0, nb)) + plt.tight_layout() + plt.subplots_adjust(wspace=0., hspace=0.2) + + +def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): """ Plot matrix M in 2D with lines using alpha values - - Plot lines between source and target 2D samples with a color + + Plot lines between source and target 2D samples with a color proportional to the value of the matrix G between samples. - - + + Parameters ---------- - - xs : np.array (ns,2) + xs : ndarray, shape (ns,2) Source samples positions - b : np.array (nt,2) + b : ndarray, shape (nt,2) Target samples positions - G : np.array (na,nb) + G : ndarray, shape (na,nb) OT matrix thr : float, optional threshold above which the line is drawn **kwargs : dict - paameters given to the plot functions (default color is black if nothing given) - + paameters given to the plot functions (default color is black if + nothing given) """ - if ('color' not in kwargs) and ('c' not in kwargs): - kwargs['color']='k' - mx=G.max() + if ('color' not in kwargs) and ('c' not in kwargs): + kwargs['color'] = 'k' + mx = G.max() for i in range(xs.shape[0]): for j in range(xt.shape[0]): - if G[i,j]/mx>thr: - pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs) - \ No newline at end of file + if G[i, j] / mx > thr: + plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], + alpha=G[i, j] / mx, **kwargs) -- cgit v1.2.3