summaryrefslogtreecommitdiff
path: root/scripts/benchmark/plot.py
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2017-04-02 14:53:55 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2017-04-02 14:53:55 +0200
commit0f96e9d2f9469f70f016ac00e16f81dfe2f084d4 (patch)
treebc274f903d0f3e1b445cc1b55f911a4ce57f0dc3 /scripts/benchmark/plot.py
parent1ee71fdc8067377d9dad27d8cdae1cec9f0fb475 (diff)
Various tweaks to the new benchmark script
Diffstat (limited to 'scripts/benchmark/plot.py')
-rw-r--r--scripts/benchmark/plot.py81
1 files changed, 41 insertions, 40 deletions
diff --git a/scripts/benchmark/plot.py b/scripts/benchmark/plot.py
index bc9529a6..275a3ba8 100644
--- a/scripts/benchmark/plot.py
+++ b/scripts/benchmark/plot.py
@@ -9,33 +9,6 @@ import utils
from matplotlib import rcParams
import matplotlib.pyplot as plt
-# Tight plot (for in a paper or presentation) or regular (for display on a screen)
-TIGHT_PLOT = False
-if TIGHT_PLOT:
- PLOT_SIZE = 5
- W_SPACE = 0.20
- H_SPACE = 0.39
- TITLE_FROM_TOP = 0.11
- LEGEND_FROM_TOP = 0.17
- LEGEND_FROM_TOP_PER_ITEM = 0.04
- X_LABEL_FROM_BOTTOM = 0.09
- LEGEND_SPACING = 0.0
- FONT_SIZE = 15
- FONT_SIZE_LEGEND = 13
- FONT_SIZE_TITLE = FONT_SIZE
-else:
- PLOT_SIZE = 8
- W_SPACE = 0.15
- H_SPACE = 0.22
- TITLE_FROM_TOP = 0.09
- LEGEND_FROM_TOP = 0.10
- LEGEND_FROM_TOP_PER_ITEM = 0.07
- X_LABEL_FROM_BOTTOM = 0.06
- LEGEND_SPACING = 0.8
- FONT_SIZE = 15
- FONT_SIZE_LEGEND = FONT_SIZE
- FONT_SIZE_TITLE = 18
-
# Colors
BLUEISH = [c / 255.0 for c in [71, 101, 177]] # #4765b1
REDISH = [c / 255.0 for c in [214, 117, 104]] # #d67568
@@ -46,7 +19,7 @@ MARKERS = ["o-", "x-", ".-"]
def plot_graphs(results, file_name, num_rows, num_cols,
x_keys, y_keys, titles, x_labels, y_labels,
- label_names, title, verbose):
+ label_names, title, tight_plot, verbose):
assert len(results) == num_rows * num_cols
assert len(results) != 1
assert len(x_keys) == len(results)
@@ -55,13 +28,41 @@ def plot_graphs(results, file_name, num_rows, num_cols,
assert len(x_labels) == len(results)
assert len(y_labels) == len(results)
+ # Tight plot (for in a paper or presentation) or regular (for display on a screen)
+ if tight_plot:
+ plot_size = 5
+ w_space = 0.20
+ h_space = 0.39
+ title_from_top = 0.11
+ legend_from_top = 0.17
+ legend_from_top_per_item = 0.04
+ x_label_from_bottom = 0.09
+ legend_spacing = 0.0
+ font_size = 15
+ font_size_legend = 13
+ font_size_title = font_size
+ bounding_box = "tight"
+ else:
+ plot_size = 8
+ w_space = 0.15
+ h_space = 0.22
+ title_from_top = 0.09
+ legend_from_top = 0.10
+ legend_from_top_per_item = 0.07
+ x_label_from_bottom = 0.06
+ legend_spacing = 0.8
+ font_size = 15
+ font_size_legend = font_size
+ font_size_title = 18
+ bounding_box = None # means not 'tight'
+
# Initializes the plot
- size_x = PLOT_SIZE * num_cols
- size_y = PLOT_SIZE * num_rows
+ size_x = plot_size * num_cols
+ size_y = plot_size * num_rows
fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(size_x, size_y), facecolor='w', edgecolor='k')
- fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=FONT_SIZE_TITLE)
- plt.subplots_adjust(wspace=W_SPACE, hspace=H_SPACE)
- rcParams.update({'font.size': FONT_SIZE})
+ fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=font_size_title)
+ plt.subplots_adjust(wspace=w_space, hspace=h_space)
+ rcParams.update({'font.size': font_size})
# Loops over each subplot
for row in range(num_rows):
@@ -78,7 +79,7 @@ def plot_graphs(results, file_name, num_rows, num_cols,
x_location = range(len(x_ticks))
# Optional sparsifying of the labels on the x-axis
- if TIGHT_PLOT and len(x_location) > 10:
+ if tight_plot and len(x_location) > 10:
x_ticks = [v if not (i % 2) else "" for i, v in enumerate(x_ticks)]
# Sets the y-data
@@ -92,11 +93,11 @@ def plot_graphs(results, file_name, num_rows, num_cols,
plt.xticks(x_location, x_ticks, rotation='vertical')
# Sets the labels
- ax.set_title(titles[index], y=1.0 - TITLE_FROM_TOP)
+ ax.set_title(titles[index], y=1.0 - title_from_top, fontsize=font_size)
if col == 0 or y_labels[index] != y_labels[index - 1]:
ax.set_ylabel(y_labels[index])
ax.set_xlabel(x_labels[index])
- ax.xaxis.set_label_coords(0.5, X_LABEL_FROM_BOTTOM)
+ ax.xaxis.set_label_coords(0.5, x_label_from_bottom)
# Plots the graph
assert len(COLORS) >= len(y_keys[index])
@@ -106,10 +107,10 @@ def plot_graphs(results, file_name, num_rows, num_cols,
ax.plot(x_location, y_list[i], MARKERS[i], label=label_names[i], color=COLORS[i])
# Sets the legend
- leg = ax.legend(loc=(0.02, 1.0 - LEGEND_FROM_TOP - LEGEND_FROM_TOP_PER_ITEM * len(y_keys[index])),
- handletextpad=0.1, labelspacing=LEGEND_SPACING, fontsize=FONT_SIZE_LEGEND)
+ leg = ax.legend(loc=(0.02, 1.0 - legend_from_top - legend_from_top_per_item * len(y_keys[index])),
+ handletextpad=0.1, labelspacing=legend_spacing, fontsize=font_size_legend)
leg.draw_frame(False)
# Saves the plot to disk
- fig.savefig(file_name, bbox_inches='tight')
- plt.show()
+ fig.savefig(file_name, bbox_inches=bounding_box)
+ # plt.show()