diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-08-09 21:06:04 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-08-09 21:06:04 +0200 |
commit | 7da6492b36cae7ba8859cd6d6ab3250e11f9a2b8 (patch) | |
tree | ad0899d8bc335aa4c3eb22a4db298ebc1224ce0f /scripts | |
parent | 3f5401d4c8947945c4770fb1dfd354892702195f (diff) |
Improved the speed of the new common-best defaults method for the database generation
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/database/database/defaults.py | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/scripts/database/database/defaults.py b/scripts/database/database/defaults.py index fca793ea..48693247 100644 --- a/scripts/database/database/defaults.py +++ b/scripts/database/database/defaults.py @@ -85,28 +85,26 @@ def get_common_best(database, group_name): parameter_column_names = [c for c in all_column_names if "parameters." in c] # Removes entries which are not available for all devices - database_common = pd.DataFrame() database_by_parameters = database.groupby(parameter_column_names) - for parameter_values, database_parameters in database_by_parameters: - num_entries = database_parameters.shape[0] - if num_entries == num_devices: - database_common = database_common.append(database_parameters) + database_common = database_by_parameters.filter(lambda x: len(x) == num_devices) # Fall back to another method in case there are no shared entries at all across devices - if database_common.shape[0] == 0: - # print("Skipping: " + str(group_name) + " with devices: " + str(num_devices) + " " + str(database.shape[0])) + if len(database_common) == 0: + # print("[database] Skipping: " + str(group_name) + " with devices: %d %d " % (num_devices, len(database))) return get_smallest_best(database) # Computes the sum of the execution times over the different devices - database_common['time'] = database_common.groupby(parameter_column_names)['time'].transform(sum) + database_common_by_parameters = database_common.groupby(parameter_column_names) + group_times = database_common_by_parameters['time'].transform(sum) + database_common.loc[:, 'group_time'] = group_times # Retrieves the entries with the best execution time - best_time = database_common["time"].min() - database_bests = database_common[database_common["time"] == best_time] + best_time = database_common["group_time"].min() + database_bests = database_common[database_common["group_time"] == best_time] # Retrieves one example only (the parameters are the same anyway) - database_bests = database_bests.drop_duplicates(["time"]) - # print(str(group_name) + " with num devices: " + str(num_devices) + " " + str(database_bests.shape)) - assert database_bests.shape[0] == 1 + database_bests = database_bests.drop_duplicates(["group_time"]) + # print("[database] " + str(group_name) + " with devices: " + str(num_devices) + " " + str(database_bests.shape)) + assert len(database_bests) == 1 return database_bests |