summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-11 21:33:13 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2017-07-20 14:05:12 +0200
commit95b2a584d02da1a08e71f7ff3895d958e42ed2dc (patch)
tree45d22dd460a25e422ec975fc4fe48318222d0da6
parentd0258f103946eb78f2fff6a3a82d85744ba27ec8 (diff)
pep8 + pimp plot1D_mat rendering
-rw-r--r--examples/plot_OT_1D.py43
-rw-r--r--ot/plot.py116
2 files changed, 75 insertions, 84 deletions
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 6661aa3..b36fa6a 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -8,49 +8,50 @@
"""
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pylab as plt
import ot
from ot.datasets import get_1D_gauss as gauss
-
#%% parameters
-n=100 # nb bins
+n = 100 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
-b=gauss(n,m=60,s=10)
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-M/=M.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
#%% plot the distributions
-pl.figure(1)
-pl.plot(x,a,'b',label='Source distribution')
-pl.plot(x,b,'r',label='Target distribution')
-pl.legend()
+plt.figure(1)
+plt.plot(x, a, 'b', label='Source distribution')
+plt.plot(x, b, 'r', label='Target distribution')
+plt.legend()
#%% plot distributions and loss matrix
-pl.figure(2)
-ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
+plt.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
#%% EMD
-G0=ot.emd(a,b,M)
+G0 = ot.emd(a, b, M)
-pl.figure(3)
-ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
+plt.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
#%% Sinkhorn
-lambd=1e-3
-Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
+lambd = 1e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+plt.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
-pl.figure(4)
-ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
+plt.show()
diff --git a/ot/plot.py b/ot/plot.py
index 9737f8a..6f01731 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -4,89 +4,79 @@ Functions for plotting OT matrices
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pylab as plt
from matplotlib import gridspec
-def plot1D_mat(a,b,M,title=''):
- """ Plot matrix M with the source and target 1D distribution
-
- Creates a subplot with the source distribution a on the left and
+def plot1D_mat(a, b, M, title=''):
+ """ Plot matrix M with the source and target 1D distribution
+
+ Creates a subplot with the source distribution a on the left and
target distribution b on the tot. The matrix M is shown in between.
-
-
+
+
Parameters
----------
-
- a : np.array (na,)
+ a : np.array, shape (na,)
Source distribution
- b : np.array (nb,)
- Target distribution
- M : np.array (na,nb)
+ b : np.array, shape (nb,)
+ Target distribution
+ M : np.array, shape (na,nb)
Matrix to plot
-
-
-
"""
-
- na=M.shape[0]
- nb=M.shape[1]
-
+ na, nb = M.shape
+
gs = gridspec.GridSpec(3, 3)
-
-
- xa=np.arange(na)
- xb=np.arange(nb)
-
-
- ax1=pl.subplot(gs[0,1:])
- pl.plot(xb,b,'r',label='Target distribution')
- pl.yticks(())
- pl.title(title)
-
- #pl.axis('off')
-
- ax2=pl.subplot(gs[1:,0])
- pl.plot(a,xa,'b',label='Source distribution')
- pl.gca().invert_xaxis()
- pl.gca().invert_yaxis()
- pl.xticks(())
- #pl.ylim((0,n))
- #pl.axis('off')
-
- pl.subplot(gs[1:,1:],sharex=ax1,sharey=ax2)
- pl.imshow(M,interpolation='nearest')
-
- pl.xlim((0,nb))
-
-
-def plot2D_samples_mat(xs,xt,G,thr=1e-8,**kwargs):
+
+ xa = np.arange(na)
+ xb = np.arange(nb)
+
+ ax1 = plt.subplot(gs[0, 1:])
+ plt.plot(xb, b, 'r', label='Target distribution')
+ plt.yticks(())
+ plt.title(title)
+
+ ax2 = plt.subplot(gs[1:, 0])
+ plt.plot(a, xa, 'b', label='Source distribution')
+ plt.gca().invert_xaxis()
+ plt.gca().invert_yaxis()
+ plt.xticks(())
+
+ plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
+ plt.imshow(M, interpolation='nearest')
+ plt.axis('off')
+
+ plt.xlim((0, nb))
+ plt.tight_layout()
+ plt.subplots_adjust(wspace=0., hspace=0.2)
+
+
+def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
""" Plot matrix M in 2D with lines using alpha values
-
- Plot lines between source and target 2D samples with a color
+
+ Plot lines between source and target 2D samples with a color
proportional to the value of the matrix G between samples.
-
-
+
+
Parameters
----------
-
- xs : np.array (ns,2)
+ xs : ndarray, shape (ns,2)
Source samples positions
- b : np.array (nt,2)
+ b : ndarray, shape (nt,2)
Target samples positions
- G : np.array (na,nb)
+ G : ndarray, shape (na,nb)
OT matrix
thr : float, optional
threshold above which the line is drawn
**kwargs : dict
- paameters given to the plot functions (default color is black if nothing given)
-
+ paameters given to the plot functions (default color is black if
+ nothing given)
"""
- if ('color' not in kwargs) and ('c' not in kwargs):
- kwargs['color']='k'
- mx=G.max()
+ if ('color' not in kwargs) and ('c' not in kwargs):
+ kwargs['color'] = 'k'
+ mx = G.max()
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
- if G[i,j]/mx>thr:
- pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs)
- \ No newline at end of file
+ if G[i, j] / mx > thr:
+ plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
+ alpha=G[i, j] / mx, **kwargs)