summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-08-09 21:06:04 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-08-09 21:06:04 +0200
commit7da6492b36cae7ba8859cd6d6ab3250e11f9a2b8 (patch)
treead0899d8bc335aa4c3eb22a4db298ebc1224ce0f /scripts
parent3f5401d4c8947945c4770fb1dfd354892702195f (diff)
Improved the speed of the new common-best defaults method for the database generation
Diffstat (limited to 'scripts')
-rw-r--r--scripts/database/database/defaults.py24
1 files changed, 11 insertions, 13 deletions
diff --git a/scripts/database/database/defaults.py b/scripts/database/database/defaults.py
index fca793ea..48693247 100644
--- a/scripts/database/database/defaults.py
+++ b/scripts/database/database/defaults.py
@@ -85,28 +85,26 @@ def get_common_best(database, group_name):
parameter_column_names = [c for c in all_column_names if "parameters." in c]
# Removes entries which are not available for all devices
- database_common = pd.DataFrame()
database_by_parameters = database.groupby(parameter_column_names)
- for parameter_values, database_parameters in database_by_parameters:
- num_entries = database_parameters.shape[0]
- if num_entries == num_devices:
- database_common = database_common.append(database_parameters)
+ database_common = database_by_parameters.filter(lambda x: len(x) == num_devices)
# Fall back to another method in case there are no shared entries at all across devices
- if database_common.shape[0] == 0:
- # print("Skipping: " + str(group_name) + " with devices: " + str(num_devices) + " " + str(database.shape[0]))
+ if len(database_common) == 0:
+ # print("[database] Skipping: " + str(group_name) + " with devices: %d %d " % (num_devices, len(database)))
return get_smallest_best(database)
# Computes the sum of the execution times over the different devices
- database_common['time'] = database_common.groupby(parameter_column_names)['time'].transform(sum)
+ database_common_by_parameters = database_common.groupby(parameter_column_names)
+ group_times = database_common_by_parameters['time'].transform(sum)
+ database_common.loc[:, 'group_time'] = group_times
# Retrieves the entries with the best execution time
- best_time = database_common["time"].min()
- database_bests = database_common[database_common["time"] == best_time]
+ best_time = database_common["group_time"].min()
+ database_bests = database_common[database_common["group_time"] == best_time]
# Retrieves one example only (the parameters are the same anyway)
- database_bests = database_bests.drop_duplicates(["time"])
- # print(str(group_name) + " with num devices: " + str(num_devices) + " " + str(database_bests.shape))
- assert database_bests.shape[0] == 1
+ database_bests = database_bests.drop_duplicates(["group_time"])
+ # print("[database] " + str(group_name) + " with devices: " + str(num_devices) + " " + str(database_bests.shape))
+ assert len(database_bests) == 1
return database_bests