summaryrefslogtreecommitdiff
path: root/scripts/benchmark/plot.py
diff options
context:
space:
mode:
authorWitold Baryluk <witold.baryluk+github@gmail.com>2020-10-05 12:11:17 +0000
committerGitHub <noreply@github.com>2020-10-05 12:11:17 +0000
commitea199c3469ed2326486f0214cb22b1bfef532a90 (patch)
tree42e26bd96867d0b68e5aa60e6a6f80da5c445d8c /scripts/benchmark/plot.py
parent3462d7fa855577064ccdee6e4527ce8a52645d5c (diff)
Allow single graph / subplot on plot
`plt.subplots` tries to be special, and return array or not-array depending on a number of subplots. It is not actually helpful, and IMHO bad design. Make it always `ndarray`. The `and not type(axes) is np.ndarray`, is just in case matplotlib decides to make their behavior more uniform. For now work around it. Also, no need for `ndarray.flat` really. Confirmed to work with existing benchmarks (i.e. rows=2, cols=3), and with single graphs (rows=1, cols=1).
Diffstat (limited to 'scripts/benchmark/plot.py')
-rw-r--r--scripts/benchmark/plot.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/scripts/benchmark/plot.py b/scripts/benchmark/plot.py
index 6337b78f..b0b63df3 100644
--- a/scripts/benchmark/plot.py
+++ b/scripts/benchmark/plot.py
@@ -10,6 +10,7 @@ import matplotlib
matplotlib.use('Agg')
from matplotlib import rcParams
import matplotlib.pyplot as plt
+import numpy as np
# Colors
BLUEISH = [c / 255.0 for c in [71, 101, 177]] # #4765b1
@@ -24,7 +25,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
x_keys, y_keys, titles, x_labels, y_labels,
label_names, title, tight_plot, verbose):
assert len(results) == num_rows * num_cols
- assert len(results) != 1
+ assert len(results) >= 1
assert len(x_keys) == len(results)
assert len(y_keys) == len(results)
assert len(titles) == len(results)
@@ -64,6 +65,9 @@ def plot_graphs(results, file_name, num_rows, num_cols,
size_y = plot_size * num_rows
rcParams.update({'font.size': font_size})
fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(size_x, size_y), facecolor='w', edgecolor='k')
+ if len(results) == 1 and not type(axes) is np.ndarray:
+ axes = np.full((1,1), axes)
+ assert type(axes) is np.ndarray
fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=font_size_title)
plt.subplots_adjust(wspace=w_space, hspace=h_space)
@@ -72,7 +76,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
for col in range(num_cols):
index = row * num_cols + col
result = results[index]
- ax = axes.flat[index]
+ ax = axes[row, col]
plt.sca(ax)
print("[plot] Plotting subplot %d" % index)