diff options
Diffstat (limited to 'ot/plot.py')
-rw-r--r-- | ot/plot.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/ot/plot.py b/ot/plot.py new file mode 100644 index 0000000..f403e98 --- /dev/null +++ b/ot/plot.py @@ -0,0 +1,91 @@ +""" +Functions for plotting OT matrices + +.. warning:: + Note that by default the module is not import in :mod:`ot`. In order to + use it you need to explicitely import :mod:`ot.plot` + + +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +from matplotlib import gridspec + + +def plot1D_mat(a, b, M, title=''): + """ Plot matrix M with the source and target 1D distribution + + Creates a subplot with the source distribution a on the left and + target distribution b on the tot. The matrix M is shown in between. + + + Parameters + ---------- + a : ndarray, shape (na,) + Source distribution + b : ndarray, shape (nb,) + Target distribution + M : ndarray, shape (na, nb) + Matrix to plot + """ + na, nb = M.shape + + gs = gridspec.GridSpec(3, 3) + + xa = np.arange(na) + xb = np.arange(nb) + + ax1 = pl.subplot(gs[0, 1:]) + pl.plot(xb, b, 'r', label='Target distribution') + pl.yticks(()) + pl.title(title) + + ax2 = pl.subplot(gs[1:, 0]) + pl.plot(a, xa, 'b', label='Source distribution') + pl.gca().invert_xaxis() + pl.gca().invert_yaxis() + pl.xticks(()) + + pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2) + pl.imshow(M, interpolation='nearest') + pl.axis('off') + + pl.xlim((0, nb)) + pl.tight_layout() + pl.subplots_adjust(wspace=0., hspace=0.2) + + +def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): + """ Plot matrix M in 2D with lines using alpha values + + Plot lines between source and target 2D samples with a color + proportional to the value of the matrix G between samples. + + + Parameters + ---------- + xs : ndarray, shape (ns,2) + Source samples positions + b : ndarray, shape (nt,2) + Target samples positions + G : ndarray, shape (na,nb) + OT matrix + thr : float, optional + threshold above which the line is drawn + **kwargs : dict + paameters given to the plot functions (default color is black if + nothing given) + """ + if ('color' not in kwargs) and ('c' not in kwargs): + kwargs['color'] = 'k' + mx = G.max() + for i in range(xs.shape[0]): + for j in range(xt.shape[0]): + if G[i, j] / mx > thr: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], + alpha=G[i, j] / mx, **kwargs) |