summaryrefslogtreecommitdiff
path: root/scripts/benchmark
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/benchmark')
-rw-r--r--scripts/benchmark/plot.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/scripts/benchmark/plot.py b/scripts/benchmark/plot.py
index 6337b78f..b0b63df3 100644
--- a/scripts/benchmark/plot.py
+++ b/scripts/benchmark/plot.py
@@ -10,6 +10,7 @@ import matplotlib
matplotlib.use('Agg')
from matplotlib import rcParams
import matplotlib.pyplot as plt
+import numpy as np
# Colors
BLUEISH = [c / 255.0 for c in [71, 101, 177]] # #4765b1
@@ -24,7 +25,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
x_keys, y_keys, titles, x_labels, y_labels,
label_names, title, tight_plot, verbose):
assert len(results) == num_rows * num_cols
- assert len(results) != 1
+ assert len(results) >= 1
assert len(x_keys) == len(results)
assert len(y_keys) == len(results)
assert len(titles) == len(results)
@@ -64,6 +65,9 @@ def plot_graphs(results, file_name, num_rows, num_cols,
size_y = plot_size * num_rows
rcParams.update({'font.size': font_size})
fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(size_x, size_y), facecolor='w', edgecolor='k')
+ if len(results) == 1 and not type(axes) is np.ndarray:
+ axes = np.full((1,1), axes)
+ assert type(axes) is np.ndarray
fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=font_size_title)
plt.subplots_adjust(wspace=w_space, hspace=h_space)
@@ -72,7 +76,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
for col in range(num_cols):
index = row * num_cols + col
result = results[index]
- ax = axes.flat[index]
+ ax = axes[row, col]
plt.sca(ax)
print("[plot] Plotting subplot %d" % index)