summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2020-12-22 15:39:15 +0100
committerGard Spreemann <gspr@nonempty.org>2020-12-22 15:39:15 +0100
commit7b1d3e5f0a1a36a469905e0b73d48cfea4d1bd46 (patch)
treee211fcdf8cee8d5841ef0dd7b41a89f542444ff7 /scripts
parent6408c2fc41fa1b04d6abf470bafb9961a28c90cd (diff)
parent8433985051c0fb9758fd8dfe7d19cc8eaca630e1 (diff)
Merge tag '1.5.1' into debian/sid
Diffstat (limited to 'scripts')
-rw-r--r--scripts/benchmark/benchmark.py183
-rw-r--r--scripts/benchmark/benchmark_all.py45
-rw-r--r--scripts/benchmark/plot.py134
-rw-r--r--scripts/benchmark/settings.py402
-rw-r--r--scripts/benchmark/utils.py69
-rwxr-xr-xscripts/database/database.py185
-rw-r--r--scripts/database/database/__init__.py0
-rw-r--r--scripts/database/database/bests.py62
-rw-r--r--scripts/database/database/clblast.py269
-rw-r--r--scripts/database/database/db.py76
-rw-r--r--scripts/database/database/defaults.py240
-rw-r--r--scripts/database/database/io.py113
-rwxr-xr-xscripts/generator/generator.py304
-rw-r--r--scripts/generator/generator/__init__.py0
-rw-r--r--scripts/generator/generator/convert.py84
-rw-r--r--scripts/generator/generator/cpp.py422
-rw-r--r--scripts/generator/generator/datatype.py119
-rw-r--r--scripts/generator/generator/doc.py57
-rw-r--r--scripts/generator/generator/pyclblast.py128
-rw-r--r--scripts/generator/generator/routine.py964
20 files changed, 3856 insertions, 0 deletions
diff --git a/scripts/benchmark/benchmark.py b/scripts/benchmark/benchmark.py
new file mode 100644
index 00000000..0bb37c10
--- /dev/null
+++ b/scripts/benchmark/benchmark.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import argparse
+import json
+import os
+import sys
+
+import settings
+import plot
+import utils
+
# Maps benchmark names (as passed on the command-line) to their settings as
# defined in settings.py.
EXPERIMENTS = {
    "axpy": settings.AXPY,
    "axpybatched": settings.AXPYBATCHED,
    "gemv": settings.GEMV,
    "gemm": settings.GEMM,
    "gemm_small": settings.GEMM_SMALL,
    "gemmbatched": settings.GEMMBATCHED,
    "gemmstridedbatched": settings.GEMMSTRIDEDBATCHED,
    "symm": settings.SYMM,
    "syrk": settings.SYRK,
    "summary": settings.SUMMARY,
}

# The supported comparison libraries, the client flags that enable them, and the
# numerical IDs under which the client reports their results (CLBlast itself is 1).
# The three lists are index-aligned.
COMPARISONS = ["clBLAS", "CPU-BLAS", "cuBLAS"]
COMPARISON_ARGS = ["-clblas", "-cblas", "-cublas"]
COMPARISON_IDS = [2, 3, 4]
+
+
def run_benchmark(name, arguments_list, precision, num_runs, platform, device, comparisons):
    """Runs the CLBlast client binary for one benchmark and parses its output.

    Args:
        name: Name of the routine to benchmark (e.g. "gemm").
        arguments_list: List of dicts of client arguments, one dict per sub-benchmark.
        precision: Precision to test (16, 32, 64, 3232, or 6464).
        num_runs: Number of repeats for averaging.
        platform: The OpenCL platform ID.
        device: The OpenCL device ID.
        comparisons: Subset of COMPARISONS to also benchmark against.

    Returns:
        A list of result dictionaries as parsed from the client's CSV output.
    """
    binary = "./clblast_client_x" + name

    # Loops over sub-benchmarks per benchmark
    results = []
    for arguments in arguments_list:

        # Sets the arguments
        # (loop variables renamed: the originals shadowed the 'name' parameter)
        constant_arguments = ["-warm_up", "-q", "-no_abbrv"]
        common_arguments = ["-precision %d" % precision, "-runs %d" % num_runs]
        opencl_arguments = ["-platform %d" % platform, "-device %d" % device]
        comparison_arguments = []
        for library, library_arg in zip(COMPARISONS, COMPARISON_ARGS):
            if library in comparisons:
                comparison_arguments.append(library_arg + " 1")
            else:
                comparison_arguments.append(library_arg + " 0")
        all_arguments = opencl_arguments + common_arguments + constant_arguments + comparison_arguments
        for key, value in arguments.items():
            all_arguments.append("-" + key + " " + str(value))

        # Calls the binary and parses the results
        benchmark_output = utils.run_binary(binary, all_arguments)
        result = utils.parse_results(benchmark_output)

        # For half-precision: also runs single-precision for comparison
        if precision == 16:
            all_arguments = [arg if arg != "-precision 16" else "-precision 32" for arg in all_arguments]
            benchmark_output = utils.run_binary(binary, all_arguments)
            result_extra = utils.parse_results(benchmark_output)
            # Bug fix: iterate up to the shorter of the two result lists. The
            # original 'len(min(result, result_extra))' compared the lists
            # lexicographically and took the length of whichever sorted first,
            # which is wrong whenever the lengths differ.
            for index in range(min(len(result), len(result_extra))):
                result[index]["GBs_1_FP32"] = result_extra[index]["GBs_1"]
                result[index]["GFLOPS_1_FP32"] = result_extra[index]["GFLOPS_1"]
                for library_id in COMPARISON_IDS:  # renamed from 'id' (shadowed a builtin)
                    if "GBs_%d" % library_id in result_extra[index].keys():
                        result[index]["GBs_%d" % library_id] = result_extra[index]["GBs_%d" % library_id]
                        result[index]["GFLOPS_%d" % library_id] = result_extra[index]["GFLOPS_%d" % library_id]

        results.extend(result)
    return results
+
+
def parse_arguments(argv):
    """Parses the command-line arguments for a single-benchmark run into a dict."""
    parser = argparse.ArgumentParser(description="Runs a full benchmark for a specific routine on a specific device")
    parser.add_argument("-b", "--benchmark", required=True, help="The benchmark to perform (choose from %s)" % sorted(EXPERIMENTS.keys()))
    parser.add_argument("-c", "--comparisons", default=[], nargs='+', help="The library(s) to compare against (choose from %s)" % COMPARISONS)
    parser.add_argument("-p", "--platform", required=True, type=int, help="The ID of the OpenCL platform to test on")
    parser.add_argument("-d", "--device", required=True, type=int, help="The ID of the OpenCL device to test on")
    parser.add_argument("-n", "--num_runs", type=int, default=None, help="Overrides the default number of benchmark repeats for averaging")
    # Bug fix: the help text was missing its closing parenthesis
    parser.add_argument("-x", "--precision", type=int, default=32, help="The precision to test for (choose from 16, 32, 64, 3232, 6464)")
    # Bug fix: the help text was copy-pasted from the --verbose option
    parser.add_argument("-l", "--load_from_disk", action="store_true", help="Loads previously-saved benchmark results from disk instead of re-running")
    parser.add_argument("-t", "--plot_title", default="", help="The title for the plots, defaults to benchmark name")
    parser.add_argument("-z", "--tight_plot", action="store_true", help="Enables tight plot layout for in paper or presentation")
    parser.add_argument("-o", "--output_folder", default=os.getcwd(), help="Sets the folder for output plots (defaults to current folder)")
    parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
    cl_args = parser.parse_args(argv)
    return vars(cl_args)
+
+
def benchmark_single(benchmark, comparisons, platform, device, num_runs, precision, load_from_disk,
                     plot_title, tight_plot, output_folder, verbose):
    """Runs one benchmark experiment and plots its results as a PDF.

    Either runs all sub-benchmarks of the requested experiment on the given
    OpenCL platform/device (storing the results as JSON), or re-uses
    previously-saved JSON results from 'output_folder' when 'load_from_disk'
    is set and the file exists.
    """

    # Sanity check
    if not os.path.isdir(output_folder):
        print("[benchmark] Error: folder '%s' doesn't exist" % output_folder)
        return

    # The benchmark name and plot title
    benchmark_name = utils.precision_to_letter(precision) + benchmark.upper()
    if benchmark.upper() != "SUMMARY":
        # Bug fix: compare string contents with '==' instead of identity ('is'),
        # which relies on implementation-dependent string interning.
        plot_title = benchmark_name if plot_title == "" else benchmark_name + ": " + plot_title

    # Retrieves the comparison settings
    library_ids = [1]  # 1 is always CLBlast itself
    for comparison in comparisons:
        if comparison not in COMPARISONS:
            print("[benchmark] Invalid comparison library '%s', choose from %s" % (comparison, COMPARISONS))
            return
        library_ids.append(COMPARISON_IDS[COMPARISONS.index(comparison)])

    # Retrieves the benchmark settings
    if benchmark not in EXPERIMENTS.keys():
        print("[benchmark] Invalid benchmark '%s', choose from %s" % (benchmark, EXPERIMENTS.keys()))
        return
    experiment = EXPERIMENTS[benchmark]
    benchmarks = experiment["benchmarks"]

    # Either run the benchmarks for this experiment or load old results from disk
    json_file_name = os.path.join(output_folder, benchmark_name.lower() + "_benchmarks.json")
    if load_from_disk and os.path.isfile(json_file_name):
        print("[benchmark] Loading previous benchmark results from '" + json_file_name + "'")
        with open(json_file_name) as f:
            results = json.load(f)
    else:

        # Runs all the individual benchmarks
        print("[benchmark] Running on platform %d, device %d" % (platform, device))
        print("[benchmark] Running %d benchmarks for settings '%s'" % (len(benchmarks), benchmark))
        results = {"label_names": ["CLBlast"] + comparisons, "num_rows": experiment["num_rows"],
                   "num_cols": experiment["num_cols"], "benchmarks": []}
        for bench in benchmarks:
            # A per-benchmark default number of runs, overridable from the command-line
            num_runs_benchmark = bench["num_runs"] if num_runs is None else num_runs
            print("[benchmark] Running benchmark '%s:%s'" % (bench["name"], bench["title"]))
            result = run_benchmark(bench["name"], bench["arguments"], precision, num_runs_benchmark,
                                   platform, device, comparisons)
            results["benchmarks"].append(result)

        # Stores the results to disk
        print("[benchmark] Saving benchmark results to '" + json_file_name + "'")
        with open(json_file_name, "w") as f:
            json.dump(results, f, sort_keys=True, indent=4)

    # Retrieves the data from the benchmark settings
    file_name_suffix = "_tight" if tight_plot else ""
    pdf_file_name = os.path.join(output_folder, benchmark_name.lower() + "_plot" + file_name_suffix + ".pdf")
    titles = [b["title"] if "BATCHED" in b["name"].upper() else
              utils.precision_to_letter(precision) + b["name"].upper() + " " + b["title"]
              for b in benchmarks]
    x_keys = [b["x_keys"] for b in benchmarks]
    y_keys = [["%s_%d" % (b["y_key"], i) for i in library_ids] for b in benchmarks]
    x_labels = [b["x_label"] for b in benchmarks]
    y_labels = [b["y_label"] for b in benchmarks]
    label_names = results["label_names"]

    # For half-precision: also adds single-precision results for comparison
    if precision == 16:
        label_names[0] += " FP16"
        for index in range(1, len(label_names)):
            label_names[index] += " FP32"
        label_names.append("CLBlast FP32")
        y_keys = [y_key + [y_key[0] + "_FP32"] for y_key in y_keys]

    # For batched routines: comparison is non-batched
    if benchmark in ["axpybatched", "gemmbatched", "gemmstridedbatched"]:
        for index in range(1, len(label_names)):
            label_names[index] += " (non-batched)"

    # Plots the graphs
    plot.plot_graphs(results["benchmarks"], pdf_file_name, results["num_rows"], results["num_cols"],
                     x_keys, y_keys, titles, x_labels, y_labels,
                     label_names, plot_title, tight_plot, verbose)

    print("[benchmark] All done")
+
+
if __name__ == '__main__':
    # Parse the command-line and hand everything to the benchmark runner
    benchmark_single(**parse_arguments(sys.argv[1:]))
diff --git a/scripts/benchmark/benchmark_all.py b/scripts/benchmark/benchmark_all.py
new file mode 100644
index 00000000..881d6bc0
--- /dev/null
+++ b/scripts/benchmark/benchmark_all.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import argparse
+import os
+import sys
+
+from benchmark import benchmark_single, COMPARISONS
+
+
# The benchmarks to run (in this order) when benchmarking everything in one go
BENCHMARKS = ["axpy", "gemv", "gemm", "summary", "axpybatched", "gemmbatched", "gemmstridedbatched"]
+
+
def parse_arguments(argv):
    """Parses the command-line arguments for an all-benchmarks run into a dict."""
    parser = argparse.ArgumentParser(description="Runs all (main) benchmarks in one go for a given device")
    parser.add_argument("-c", "--comparisons", default=[], nargs='+', help="The library(s) to compare against (choose from %s)" % COMPARISONS)
    parser.add_argument("-p", "--platform", required=True, type=int, help="The ID of the OpenCL platform to test on")
    parser.add_argument("-d", "--device", required=True, type=int, help="The ID of the OpenCL device to test on")
    # Bug fix: the help text was missing its closing parenthesis
    parser.add_argument("-x", "--precision", type=int, default=32, help="The precision to test for (choose from 16, 32, 64, 3232, 6464)")
    # Bug fix: the help text was copy-pasted from the --verbose option
    parser.add_argument("-l", "--load_from_disk", action="store_true", help="Loads previously-saved benchmark results from disk instead of re-running")
    parser.add_argument("-t", "--plot_title", default="", help="The title for the plots, defaults to benchmark name")
    parser.add_argument("-o", "--output_folder", default=os.getcwd(), help="Sets the folder for output plots (defaults to current folder)")
    parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
    cl_args = parser.parse_args(argv)
    return vars(cl_args)
+
+
def benchmark_all(comparisons, platform, device, precision, load_from_disk,
                  plot_title, output_folder, verbose):
    """Runs every benchmark in BENCHMARKS, producing a tight and a regular plot each."""
    for benchmark in BENCHMARKS:
        # Two plots per benchmark: the first run may load from disk if requested;
        # after it, the results are guaranteed to exist on disk, so the second
        # plot always re-uses them instead of benchmarking again.
        for plot_index, tight_plot in enumerate([True, False]):
            use_disk = load_from_disk or plot_index > 0
            benchmark_single(benchmark, comparisons, platform, device, None, precision, use_disk,
                             plot_title, tight_plot, output_folder, verbose)


if __name__ == '__main__':
    benchmark_all(**parse_arguments(sys.argv[1:]))
diff --git a/scripts/benchmark/plot.py b/scripts/benchmark/plot.py
new file mode 100644
index 00000000..6337b78f
--- /dev/null
+++ b/scripts/benchmark/plot.py
@@ -0,0 +1,134 @@
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import utils
+
+import matplotlib
+matplotlib.use('Agg')
+from matplotlib import rcParams
+import matplotlib.pyplot as plt
+
# Colors (RGB triplets scaled to [0, 1] as expected by matplotlib) and the
# line markers used for the plotted curves. COLORS and MARKERS are indexed by
# curve position; plot_graphs may override them per label name.
BLUEISH = [c / 255.0 for c in [71, 101, 177]]  # #4765b1
REDISH = [c / 255.0 for c in [214, 117, 104]]  # #d67568
PURPLISH = [c / 255.0 for c in [85, 0, 119]]  # #550077
GREEN = [c / 255.0 for c in [144, 224, 98]]  # #90e062
COLORS = [BLUEISH, REDISH, PURPLISH, GREEN]
MARKERS = ["o-", "x-", ".-"]
+
+
def plot_graphs(results, file_name, num_rows, num_cols,
                x_keys, y_keys, titles, x_labels, y_labels,
                label_names, title, tight_plot, verbose):
    """Plots a num_rows x num_cols grid of benchmark graphs and saves it as a PDF.

    Args:
        results: List of result-lists, one per subplot (row-major order).
        file_name: Output PDF file name.
        num_rows/num_cols: Grid dimensions; their product must equal len(results).
        x_keys/y_keys: Per-subplot lists of result-dict keys for each axis.
        titles/x_labels/y_labels: Per-subplot title and axis labels.
        label_names: Legend names, one per curve (shared by all subplots).
        title: Overall figure title.
        tight_plot: Uses a compact layout suitable for papers/presentations.
        verbose: Currently unused; kept for interface compatibility.
    """
    assert len(results) == num_rows * num_cols
    assert len(results) != 1  # a single subplot would break the axes.flat indexing below
    assert len(x_keys) == len(results)
    assert len(y_keys) == len(results)
    assert len(titles) == len(results)
    assert len(x_labels) == len(results)
    assert len(y_labels) == len(results)

    # Tight plot (for in a paper or presentation) or regular (for display on a screen)
    if tight_plot:
        plot_size = 5
        w_space = 0.20
        h_space = 0.39
        title_from_top = 0.11
        legend_from_top = 0.17
        legend_from_top_per_item = 0.04
        x_label_from_bottom = 0.09
        legend_spacing = 0.0
        font_size = 15
        font_size_legend = 13
        font_size_title = font_size
        bounding_box = "tight"
    else:
        plot_size = 8
        w_space = 0.15
        h_space = 0.22
        title_from_top = 0.09
        legend_from_top = 0.10
        legend_from_top_per_item = 0.07
        x_label_from_bottom = 0.06
        legend_spacing = 0.8
        font_size = 15
        font_size_legend = font_size
        font_size_title = 18
        bounding_box = None  # means not 'tight'

    # Initializes the plot
    size_x = plot_size * num_cols
    size_y = plot_size * num_rows
    rcParams.update({'font.size': font_size})
    fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(size_x, size_y), facecolor='w', edgecolor='k')
    fig.text(.5, 0.92, title, horizontalalignment="center", fontsize=font_size_title)
    plt.subplots_adjust(wspace=w_space, hspace=h_space)

    # Loops over each subplot
    for row in range(num_rows):
        for col in range(num_cols):
            index = row * num_cols + col
            result = results[index]
            ax = axes.flat[index]
            plt.sca(ax)
            print("[plot] Plotting subplot %d" % index)

            # Sets the x-axis labels
            x_list = [[r[x_key] for r in result] for x_key in x_keys[index]]
            x_ticks = [",".join([utils.float_to_kilo_mega(v) for v in values]) for values in zip(*x_list)]
            x_location = range(len(x_ticks))

            # Optional sparsifying of the labels on the x-axis
            if tight_plot and len(x_location) > 10:
                x_ticks = [v if not (i % 2) else "" for i, v in enumerate(x_ticks)]

            # Sets the y-data (missing keys plot as 0, e.g. a library without results)
            y_list = [[r[y_key] if y_key in r.keys() else 0 for r in result] for y_key in y_keys[index]]
            y_max = [max(y) if len(y) else 1 for y in y_list]
            y_max = max(y_max) if len(y_list) > 0 else 1

            # Sets the axes: round the y-limit up to a 'nice' multiple above the data
            # Fix: positional ylim arguments work across all matplotlib versions;
            # the 'ymin'/'ymax' keywords were deprecated and removed in matplotlib 3.0.
            y_rounding = 10 if y_max < 80 else 50 if y_max < 400 else 200
            y_axis_limit = (y_max * 1.2) - ((y_max * 1.2) % y_rounding) + y_rounding
            plt.ylim(0, y_axis_limit)
            plt.xticks(x_location, x_ticks, rotation='vertical')

            # Sets the labels (y-label only when it differs from the previous subplot)
            ax.set_title(titles[index], y=1.0 - title_from_top, fontsize=font_size)
            if col == 0 or y_labels[index] != y_labels[index - 1]:
                ax.set_ylabel(y_labels[index])
            ax.set_xlabel(x_labels[index])
            ax.xaxis.set_label_coords(0.5, x_label_from_bottom)

            # Plots the graph, with fixed colors/markers for the known library names
            assert len(COLORS) >= len(y_keys[index])
            assert len(MARKERS) >= len(y_keys[index])
            assert len(label_names) == len(y_keys[index])
            for i in range(len(y_keys[index])):
                color = COLORS[i]
                marker = MARKERS[i]
                if label_names[i] in ["CLBlast", "CLBlast FP32"]:
                    color = BLUEISH
                    marker = "o-"
                elif label_names[i] in ["CLBlast FP16"]:
                    color = PURPLISH
                    marker = ".-"
                elif label_names[i] in ["clBLAS", "clBLAS FP32", "clBLAS (non-batched)"]:
                    color = REDISH
                    marker = "x-"
                elif label_names[i] in ["cuBLAS", "cuBLAS (non-batched)"]:
                    color = GREEN
                    marker = ".-"
                ax.plot(x_location, y_list[i], marker, label=label_names[i], color=color)

            # Sets the legend
            leg = ax.legend(loc=(0.02, 1.0 - legend_from_top - legend_from_top_per_item * len(y_keys[index])),
                            handletextpad=0.1, labelspacing=legend_spacing, fontsize=font_size_legend)
            leg.draw_frame(False)

    # Saves the plot to disk
    # Fix: use the '[plot]' log prefix for consistency with the message above
    # (this module previously printed '[benchmark]' here).
    print("[plot] Saving plot to '" + file_name + "'")
    fig.savefig(file_name, bbox_inches=bounding_box)
diff --git a/scripts/benchmark/settings.py b/scripts/benchmark/settings.py
new file mode 100644
index 00000000..bf7d3621
--- /dev/null
+++ b/scripts/benchmark/settings.py
@@ -0,0 +1,402 @@
+#!/usr/bin/env python
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import utils
+
+
# Each experiment below is a dict with the plot-grid shape ("num_rows" x
# "num_cols") and a list of "benchmarks". Each benchmark entry names the
# client routine, its default number of runs, plot title/axis labels, the
# result keys to plot ("x_keys"/"y_key"), and the client argument sets.
AXPY = {
    "num_rows": 2, "num_cols": 3,
    "benchmarks": [
        {
            "name": "axpy", "num_runs": 40,
            "title": "multiples of 256K",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": utils.k(256), "incx": 1, "incy": 1, "step": utils.k(256), "num_steps": 16}],
        },
        {
            "name": "axpy", "num_runs": 40,
            "title": "multiples of 256K+1",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": utils.k(256) + 1, "incx": 1, "incy": 1, "step": utils.k(256) + 1, "num_steps": 16}],
        },
        {
            "name": "axpy", "num_runs": 40,
            "title": "around 1M",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": utils.m(1), "incx": 1, "incy": 1, "step": 1, "num_steps": 16}],
        },
        {
            "name": "axpy", "num_runs": 20,
            "title": "around 16M",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": utils.m(16), "incx": 1, "incy": 1, "step": 1, "num_steps": 16}],
        },
        {
            "name": "axpy", "num_runs": 20,
            "title": "strides n=8M",
            "x_label": "increments for x,y", "x_keys": ["incx", "incy"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": utils.m(8), "incx": inc_x, "incy": inc_y, "step": 0, "num_steps": 1}
                          for inc_x in [1, 2, 4] for inc_y in [1, 2, 4]],
        },
        {
            "name": "axpy", "num_runs": 40,
            "title": "powers of 2",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": n, "incx": 1, "incy": 1, "step": 0, "num_steps": 1}
                          for n in utils.powers_of_2(utils.k(32), utils.m(64))],
        }
    ]
}

# Batched AXPY: varies the vector size at a fixed batch count and vice versa.
AXPYBATCHED = {
    "num_rows": 1, "num_cols": 3,
    "benchmarks": [
        {
            "name": "axpybatched", "num_runs": 10,
            "title": "num AXPYs = 8",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"batch_num": 8, "n": n, "incx": 1, "incy": 1, "step": 0, "num_steps": 1}
                          for n in utils.powers_of_2(utils.k(8), utils.m(4))],
        },
        {
            "name": "axpybatched", "num_runs": 5,
            "title": "num AXPYs = 64",
            "x_label": "sizes (n)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"batch_num": 64, "n": n, "incx": 1, "incy": 1, "step": 0, "num_steps": 1}
                          for n in utils.powers_of_2(utils.k(8), utils.m(4))],
        },
        {
            "name": "axpybatched", "num_runs": 10,
            "title": "n=512K",
            "x_label": "num AXPYs", "x_keys": ["batch_num"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"batch_num": b, "n": utils.k(512), "incx": 1, "incy": 1, "step": 1, "num_steps": 1}
                          for b in utils.powers_of_2(1, 256)],
        }
    ]
}

# GEMV: "layout" 101/102 is row/column-major (CBLAS numbering).
GEMV = {
    "num_rows": 2, "num_cols": 3,
    "benchmarks": [
        {
            "name": "gemv", "num_runs": 40,
            "title": "multiples of 256",
            "x_label": "sizes (n=m)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": 256, "m": 256, "incx": 1, "incy": 1, "layout": 102, "step": 256, "num_steps": 20}],
        },
        {
            "name": "gemv", "num_runs": 40,
            "title": "multiples of 257",
            "x_label": "sizes (n=m)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": 257, "m": 257, "incx": 1, "incy": 1, "layout": 102, "step": 257, "num_steps": 20}],
        },
        {
            "name": "gemv", "num_runs": 20,
            "title": "around 4K",
            "x_label": "sizes (n=m)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": 4096, "m": 4096, "incx": 1, "incy": 1, "layout": 102, "step": 1, "num_steps": 16}],
        },
        {
            "name": "gemv", "num_runs": 40,
            "title": "multiples of 256 rotated",
            "x_label": "sizes (n=m)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": 256, "m": 256, "incx": 1, "incy": 1, "layout": 101, "step": 256, "num_steps": 20}],
        },
        {
            "name": "gemv", "num_runs": 40,
            "title": "multiples of 257 rotated",
            "x_label": "sizes (n=m)", "x_keys": ["n"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": 257, "m": 257, "incx": 1, "incy": 1, "layout": 101, "step": 257, "num_steps": 20}],
        },
        {
            "name": "gemv", "num_runs": 20,
            "title": "strides n=m=4K",
            "x_label": "increments/strides for x,y", "x_keys": ["incx", "incy"],
            "y_label": "GB/s (higher is better)", "y_key": "GBs",
            "arguments": [{"n": 4096, "m": 4096, "incx": inc_x, "incy": inc_y, "layout": 102, "step": 0, "num_steps": 1}
                          for inc_x in [1, 2, 4] for inc_y in [1, 2, 4]],
        }
    ]
}

# GEMM: "transA"/"transB" 111/112 is no-transpose/transpose (CBLAS numbering).
GEMM = {
    "num_rows": 2, "num_cols": 3,
    "benchmarks": [
        {
            "name": "gemm", "num_runs": 20,
            "title": "multiples of 128",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 128, "n": 128, "k": 128, "layout": 102,
                           "transA": 111, "transB": 111, "step": 128, "num_steps": 20}],
        },
        {
            "name": "gemm", "num_runs": 20,
            "title": "multiples of 129",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 129, "n": 129, "k": 129, "layout": 102,
                           "transA": 111, "transB": 111, "step": 129, "num_steps": 20}],
        },
        {
            "name": "gemm", "num_runs": 20,
            "title": "around 512",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 512, "n": 512, "k": 512, "layout": 102,
                           "transA": 111, "transB": 111, "step": 1, "num_steps": 16}],
        },
        {
            "name": "gemm", "num_runs": 10,
            "title": "around 2048",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 2048, "n": 2048, "k": 2048, "layout": 102,
                           "transA": 111, "transB": 111, "step": 1, "num_steps": 16}],
        },
        {
            "name": "gemm", "num_runs": 10,
            "title": "layouts/transpose",
            "x_label": "layout, transA, transB", "x_keys": ["layout", "transA", "transB"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 1024, "n": 1024, "k": 1024, "layout": layout,
                           "transA": transA, "transB": transB, "step": 0, "num_steps": 1}
                          for layout in [101, 102] for transA in [111, 112] for transB in [111, 112]],
        },
        {
            "name": "gemm", "num_runs": 10,
            "title": "powers of 2",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": n, "n": n, "k": n, "layout": 102,
                           "transA": 111, "transB": 111, "step": 0, "num_steps": 1}
                          for n in utils.powers_of_2(8, utils.k(4))],
        }
    ]
}

# GEMM for small matrix sizes only (fine-grained sweeps).
GEMM_SMALL = {
    "num_rows": 2, "num_cols": 1,
    "benchmarks": [
        {
            "name": "gemm", "num_runs": 10,
            "title": "small matrices in steps of 16",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 128, "n": 128, "k": 128, "layout": 102,
                           "transA": 111, "transB": 111, "step": 16, "num_steps": 57}],
        },
        {
            "name": "gemm", "num_runs": 10,
            "title": "small matrices in steps of 1",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 128, "n": 128, "k": 128, "layout": 102,
                           "transA": 111, "transB": 111, "step": 1, "num_steps": 385}],
        },

    ]
}

# Batched GEMM: varies the matrix size at fixed batch counts and vice versa.
GEMMBATCHED = {
    "num_rows": 1, "num_cols": 3,
    "benchmarks": [
        {
            "name": "gemmbatched", "num_runs": 20,
            "title": "num GEMMs = 8",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"batch_num": 8, "m": 32, "n": 32, "k": 32, "layout": 102,
                           "transA": 111, "transB": 111, "step": 32, "num_steps": 20}],
        },
        {
            "name": "gemmbatched", "num_runs": 10,
            "title": "num GEMMs = 64",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"batch_num": 64, "m": 32, "n": 32, "k": 32, "layout": 102,
                           "transA": 111, "transB": 111, "step": 32, "num_steps": 20}],
        },
        {
            "name": "gemmbatched", "num_runs": 10,
            "title": "m=n=k=128",
            "x_label": "num GEMMs", "x_keys": ["batch_num"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            # NOTE(review): unlike the entries above, no "step"/"num_steps" here —
            # presumably the client defaults apply; confirm against the client.
            "arguments": [{"batch_num": b, "m": 128, "n": 128, "k": 128, "layout": 102,
                           "transA": 111, "transB": 111} for b in utils.powers_of_2(1, utils.k(4))],
        }
    ]
}

# Strided-batched GEMM: same shape as GEMMBATCHED but for the strided variant.
GEMMSTRIDEDBATCHED = {
    "num_rows": 1, "num_cols": 3,
    "benchmarks": [
        {
            "name": "gemmstridedbatched", "num_runs": 20,
            "title": "num GEMMs = 8",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"batch_num": 8, "m": 32, "n": 32, "k": 32, "layout": 102,
                           "transA": 111, "transB": 111, "step": 32, "num_steps": 20}],
        },
        {
            "name": "gemmstridedbatched", "num_runs": 10,
            "title": "num GEMMs = 64",
            "x_label": "sizes (m=n=k)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"batch_num": 64, "m": 32, "n": 32, "k": 32, "layout": 102,
                           "transA": 111, "transB": 111, "step": 32, "num_steps": 20}],
        },
        {
            "name": "gemmstridedbatched", "num_runs": 10,
            "title": "m=n=k=128",
            "x_label": "num GEMMs", "x_keys": ["batch_num"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"batch_num": b, "m": 128, "n": 128, "k": 128, "layout": 102,
                           "transA": 111, "transB": 111} for b in utils.powers_of_2(1, utils.k(4))],
        }
    ]
}

# SYMM: "side" 141/142 is left/right, "triangle" 121/122 is upper/lower.
SYMM = {
    "num_rows": 2, "num_cols": 3,
    "benchmarks": [
        {
            "name": "symm", "num_runs": 10,
            "title": "multiples of 128",
            "x_label": "sizes (m=n)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 128, "n": 128, "layout": 102,
                           "side": 141, "triangle": 121, "step": 128, "num_steps": 20}],
        },
        {
            "name": "symm", "num_runs": 10,
            "title": "multiples of 129",
            "x_label": "sizes (m=n)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 129, "n": 129, "layout": 102,
                           "side": 141, "triangle": 121, "step": 129, "num_steps": 20}],
        },
        {
            "name": "symm", "num_runs": 10,
            "title": "around 512",
            "x_label": "sizes (m=n)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 512, "n": 512, "layout": 102,
                           "side": 141, "triangle": 121, "step": 1, "num_steps": 16}],
        },
        {
            "name": "symm", "num_runs": 10,
            "title": "around 2048",
            "x_label": "sizes (m=n)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 2048, "n": 2048, "layout": 102,
                           "side": 141, "triangle": 121, "step": 1, "num_steps": 16}],
        },
        {
            "name": "symm", "num_runs": 10,
            "title": "layouts/sides/triangles",
            "x_label": "layout, side, triangle", "x_keys": ["layout", "side", "triangle"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": 1024, "n": 1024, "layout": layout,
                           "side": side, "triangle": triangle, "step": 0, "num_steps": 1}
                          for layout in [101, 102] for side in [141, 142] for triangle in [121, 122]],
        },
        {
            "name": "symm", "num_runs": 10,
            "title": "powers of 2",
            "x_label": "sizes (m=n)", "x_keys": ["m"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"m": n, "n": n, "layout": 102,
                           "side": 141, "triangle": 121, "step": 0, "num_steps": 1}
                          for n in utils.powers_of_2(8, utils.k(4))],
        }
    ]
}

# SYRK benchmarks.
# NOTE(review): several entries below pass "side", which is not a SYRK
# parameter (looks copy-pasted from SYMM) — presumably the client ignores
# unknown arguments; confirm against the client's argument parsing.
SYRK = {
    "num_rows": 2, "num_cols": 3,
    "benchmarks": [
        {
            "name": "syrk", "num_runs": 10,
            "title": "multiples of 128",
            "x_label": "sizes (n=k)", "x_keys": ["n"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"n": 128, "k": 128, "layout": 102,
                           "side": 141, "triangle": 121, "step": 128, "num_steps": 20}],
        },
        {
            "name": "syrk", "num_runs": 10,
            "title": "multiples of 129",
            "x_label": "sizes (n=k)", "x_keys": ["n"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"n": 129, "k": 129, "layout": 102,
                           "side": 141, "triangle": 121, "step": 129, "num_steps": 20}],
        },
        {
            "name": "syrk", "num_runs": 10,
            "title": "around 512",
            "x_label": "sizes (n=k)", "x_keys": ["n"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"n": 512, "k": 512, "layout": 102,
                           "side": 141, "triangle": 121, "step": 1, "num_steps": 16}],
        },
        {
            "name": "syrk", "num_runs": 10,
            "title": "around 2048",
            "x_label": "sizes (n=k)", "x_keys": ["n"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"n": 2048, "k": 2048, "layout": 102,
                           "side": 141, "triangle": 121, "step": 1, "num_steps": 16}],
        },
        {
            "name": "syrk", "num_runs": 10,
            "title": "layouts/sides/triangles",
            "x_label": "layout, triangle, transA", "x_keys": ["layout", "triangle", "transA"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"n": 1024, "k": 1024, "layout": layout,
                           "triangle": triangle, "transA": transA, "step": 0, "num_steps": 1}
                          for layout in [101, 102] for triangle in [121, 122] for transA in [111, 112]],
        },
        {
            "name": "syrk", "num_runs": 10,
            "title": "powers of 2",
            "x_label": "sizes (n=k)", "x_keys": ["n"],
            "y_label": "GFLOPS (higher is better)", "y_key": "GFLOPS",
            "arguments": [{"n": n, "k": n, "layout": 102,
                           "side": 141, "triangle": 121, "step": 0, "num_steps": 1}
                          for n in utils.powers_of_2(8, utils.k(4))],
        }
    ]
}

# A 3x2 overview combining the first two entries of AXPY, GEMV, and GEMM.
SUMMARY = {
    "num_rows": 3, "num_cols": 2,
    "benchmarks": [
        AXPY["benchmarks"][0],
        AXPY["benchmarks"][1],
        GEMV["benchmarks"][0],
        GEMV["benchmarks"][1],
        GEMM["benchmarks"][0],
        GEMM["benchmarks"][1],
    ]
}
diff --git a/scripts/benchmark/utils.py b/scripts/benchmark/utils.py
new file mode 100644
index 00000000..11aad805
--- /dev/null
+++ b/scripts/benchmark/utils.py
@@ -0,0 +1,69 @@
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import csv
+import subprocess
+
+
def k(value):
    """Scales 'value' by a binary kilo factor (1024)"""
    return 1024 * value
+
+
def m(value):
    """Scales 'value' by a binary mega factor (1024 * 1024)"""
    return 1024 * 1024 * value
+
+
def float_to_kilo_mega(value):
    """Formats a number as a plain string, or with a 'K'/'M' suffix when it is an exact
    multiple of 1024 (resp. 1024^2) strictly greater than that unit"""
    kilo = 1024
    mega = kilo * kilo
    if value % kilo or value <= kilo:
        return "%.0f" % value
    if value % mega or value <= mega:
        return "%.0fK" % (value / float(kilo))
    return "%.0fM" % (value / float(mega))
+
+
def powers_of_2(start, stop):
    """Generates start, 2*start, 4*start, ... up to and including 'stop'"""
    current = start
    while current <= stop:
        yield current
        current += current
+
+
def precision_to_letter(precision):
    """Maps a CLBlast precision number onto its single-letter BLAS prefix ('X' when unknown)"""
    letters = {16: "H", 32: "S", 64: "D", 3232: "C", 6464: "Z"}
    return letters.get(precision, "X")
+
+
def run_binary(command, arguments):
    """Runs the binary 'command' with the given list of string arguments and returns its
    standard output (as bytes under Python 3), or False if the binary could not be started."""
    full_command = command + " " + " ".join(arguments)
    print("[benchmark] Calling binary: %s" % str(full_command))
    try:
        return subprocess.Popen(full_command, shell=True, stdout=subprocess.PIPE).stdout.read()
    except OSError as e:
        # Bug fix: the original used '+' after a '%s' format string, which printed the literal
        # '%s' followed by the exception text instead of substituting it; '%' is intended here.
        print("[benchmark] Error while running the binary, got exception: %s" % str(e))
        return False
+
+
def parse_results(csv_data):
    """Parses semicolon-separated CSV text into a list of row dictionaries, converting numeric
    fields to int/float while leaving fields containing an 'i' (e.g. complex values) as strings"""
    reader = csv.DictReader(csv_data.split("\n"), delimiter=";", skipinitialspace=True)
    rows = list(reader)
    for row in rows:
        for key, value in row.items():
            if "i" not in value:
                row[key] = float(value) if "." in value else int(value)
    return rows
diff --git a/scripts/database/database.py b/scripts/database/database.py
new file mode 100755
index 00000000..6bd52760
--- /dev/null
+++ b/scripts/database/database.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import sys
+import os.path
+import glob
+import argparse
+
+import database.io as io
+import database.db as db
+import database.clblast as clblast
+import database.bests as bests
+import database.defaults as defaults
+
+# Server storing a copy of the database
+DATABASE_SERVER_URL = "https://raw.githubusercontent.com/CNugteren/CLBlast-database/master/database.json"
+
+
def remove_mismatched_arguments(database):
    """Checks for tuning results with mis-matched entries and removes them according to user preferences"""
    kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]

    # For Python 2 and 3 compatibility
    try:
        user_input = raw_input
    except NameError:
        user_input = input
        pass

    # Check for mis-matched entries
    for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
        group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
        # More than one argument-combination for the same kernel means the entries conflict
        if len(group_by_arguments) != 1:
            print("[database] WARNING: entries for a single kernel with multiple argument values " +
                  str(kernel_group_name))
            print("[database] Either quit or remove all but one of the argument combinations below:")
            for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
                print("[database] %d: %s" % (index, attribute_group_name))
            # Interactive: the user decides per argument-combination whether it gets dropped
            for attribute_group_name, mismatching_entries in group_by_arguments:
                response = user_input("[database] Remove entries corresponding to %s, [y/n]? " %
                                      str(attribute_group_name))
                if response == "y":
                    for entry in mismatching_entries:
                        database["sections"].remove(entry)
                    print("[database] Removed %d entry/entries" % len(mismatching_entries))

    # Sanity-check: all mis-matched entries should be removed
    for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
        group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
        if len(group_by_arguments) != 1:
            print("[database] ERROR: entries for a single kernel with multiple argument values " +
                  str(kernel_group_name))
        assert len(group_by_arguments) == 1
+
+
def remove_database_entries(database, remove_if_matches_fields):
    """Deletes every section whose fields all match 'remove_if_matches_fields', reporting the
    number of removed sections"""
    assert len(remove_if_matches_fields.keys()) > 0

    def matches(section):
        # A section is removed only when it matches on every given field
        return all(section[key] == value for key, value in remove_if_matches_fields.items())

    count_before = len(database["sections"])
    database["sections"] = [section for section in database["sections"] if not matches(section)]
    count_after = len(database["sections"])
    print("[database] Removed %d entries from the database" % (count_before - count_after))
+
+
def add_tuning_parameter(database, parameter_name, kernel, value):
    """Adds the tuning parameter 'parameter_name' with default 'value' to every result of the
    given 'kernel' that does not define it yet, and registers the name in the section's
    'parameter_names' list.

    Bug fix: the name is now appended to 'parameter_names' at most once per section; previously
    it was appended once per modified result, producing duplicate name entries."""
    num_changes = 0
    for section in database["sections"]:
        if section["kernel"] == kernel:
            for result in section["results"]:
                if parameter_name not in result["parameters"]:
                    result["parameters"][parameter_name] = value
                    if parameter_name not in section["parameter_names"]:
                        section["parameter_names"].append(parameter_name)
                    num_changes += 1
    print("[database] Made %d addition(s) of %s" % (num_changes, parameter_name))
+
+
def main(argv):
    """Imports new tuning results (JSON files) into the CLBlast database, optionally removes
    entries or adds tuning parameters, and re-generates the C++ database sources."""

    # Parses the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("source_folder", help="The folder with JSON files to parse to add to the database")
    parser.add_argument("clblast_root", help="Root of the CLBlast sources")
    parser.add_argument("-r", "--remove_device", type=str, default=None, help="Removes all entries for a specific device")
    parser.add_argument("--add_tuning_parameter", type=str, default=None, help="Adds this parameter to existing entries")
    parser.add_argument("--add_tuning_parameter_for_kernel", type=str, default=None, help="Adds the above parameter for this kernel")
    parser.add_argument("--add_tuning_parameter_value", type=int, default=0, help="Set this value as the default for the above parameter")
    parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
    cl_args = parser.parse_args(argv)

    # Parses the path arguments
    database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.json")
    database_best_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database_best.json")
    json_files = os.path.join(cl_args.source_folder, "*.json")
    cpp_database_path = os.path.join(cl_args.clblast_root, "src", "database", "kernels")

    # Checks whether the command-line arguments are valid
    clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h")  # Not used but just for validation
    if not os.path.isfile(clblast_header):
        raise RuntimeError("The path '" + cl_args.clblast_root +
                           "' does not point to the root of the CLBlast library")
    if len(glob.glob(json_files)) < 1:
        print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files")

    # Downloads the database if a local copy is not present
    if not os.path.isfile(database_filename):
        io.download_database(database_filename, DATABASE_SERVER_URL)

    # Loads the database from disk
    database = io.load_database(database_filename)

    # Loops over all JSON files in the supplied folder
    for file_json in glob.glob(json_files):
        sys.stdout.write("[database] Processing '" + file_json + "' ")  # No newline printed

        try:
            # Loads the newly imported data
            imported_data = io.load_tuning_results(file_json)

            # Adds the new data to the database
            old_size = db.length(database)
            database = db.add_section(database, imported_data)
            new_size = db.length(database)
            print("with " + str(new_size - old_size) + " new items")  # Newline printed here

        except ValueError:
            print("--- WARNING: invalid file, skipping")

    # Checks for tuning results with mis-matched entries
    remove_mismatched_arguments(database)

    # Stores the modified database back to disk
    if len(glob.glob(json_files)) >= 1:
        io.save_database(database, database_filename)

    # Removes database entries before continuing
    if cl_args.remove_device is not None:
        print("[database] Removing all results for device '%s'" % cl_args.remove_device)
        remove_database_entries(database, {"clblast_device_name": cl_args.remove_device})
        #, "kernel_family": "xgemm"})
        io.save_database(database, database_filename)

    # Adds new tuning parameters to existing database entries
    if cl_args.add_tuning_parameter is not None and\
       cl_args.add_tuning_parameter_for_kernel is not None:
        print("[database] Adding tuning parameter: '%s' for kernel '%s' with default %d" %
              (cl_args.add_tuning_parameter, cl_args.add_tuning_parameter_for_kernel,
               cl_args.add_tuning_parameter_value))
        add_tuning_parameter(database, cl_args.add_tuning_parameter,
                             cl_args.add_tuning_parameter_for_kernel,
                             cl_args.add_tuning_parameter_value)
        io.save_database(database, database_filename)

    # Retrieves the best performing results
    print("[database] Calculating the best results per device/kernel...")
    database_best_results = bests.get_best_results(database)

    # Determines the defaults for other vendors and per vendor
    print("[database] Calculating the default values...")
    database_defaults = defaults.calculate_defaults(database, cl_args.verbose)
    database_best_results["sections"].extend(database_defaults["sections"])

    # Optionally outputs the database to disk
    if cl_args.verbose:
        io.save_database(database_best_results, database_best_filename)

    # Outputs the database as a C++ database
    print("[database] Producing a C++ database in '" + cpp_database_path + "'...")
    clblast.print_cpp_database(database_best_results, cpp_database_path)

    print("[database] All done")
diff --git a/scripts/database/database/__init__.py b/scripts/database/database/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/scripts/database/database/__init__.py
diff --git a/scripts/database/database/bests.py b/scripts/database/database/bests.py
new file mode 100644
index 00000000..c87b80de
--- /dev/null
+++ b/scripts/database/database/bests.py
@@ -0,0 +1,62 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import sys
+
+import database.clblast as clblast
+
+
def get_best_results(database):
    """Retrieves, per section, only the result with the lowest execution time"""
    best_sections = []
    for section in database["sections"]:

        # Copies all of the section's meta data (everything except the raw results)
        section_best = {key: value for key, value in section.items() if key != "results"}
        if section_best["clblast_device_architecture"] == "" and section_best["clblast_device_vendor"] in clblast.VENDORS_WITH_ARCHITECTURE:
            section_best["clblast_device_architecture"] = clblast.DEVICE_ARCHITECTURE_DEFAULT

        # Scans for the fastest result
        best_time = sys.float_info.max
        best_parameters = None
        for result in section["results"]:
            if result["time"] < best_time:
                best_time = result["time"]
                best_parameters = result["parameters"]

        # Keeps only the fastest result
        section_best["results"] = [{"time": best_time, "parameters": best_parameters}]
        best_sections.append(section_best)

    return {"sections": best_sections}
+
+
def get_relative_bests(name, common_results, common_parameters, verbose=False):
    """Retrieves the parameters with the relative best execution time over different devices"""

    # Sums the relative execution times per parameter combination over all devices
    totals = [
        sum(r["relative_time"] for r in common_results if r["parameters"] == parameters)
        for parameters in common_parameters
    ]

    # Picks the combination with the lowest total (first one on ties)
    best_index = min(range(len(totals)), key=lambda i: totals[i])

    # Completed, report and return the results
    if verbose:
        print("[database] " + str(name) + " with performance " + str(totals[best_index]))
    return common_parameters[best_index]
diff --git a/scripts/database/database/clblast.py b/scripts/database/database/clblast.py
new file mode 100644
index 00000000..ce76f305
--- /dev/null
+++ b/scripts/database/database/clblast.py
@@ -0,0 +1,269 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import os
+
# Type settings (also change in database_structure.hpp)
STRING_LENGTH = 50  # maximum device-name length stored in the C++ database
PARAMETERS_LENGTH = 16  # fixed number of parameter slots per C++ 'Params' entry (zero-padded)

# Constants from the C++ code
VENDOR_DEFAULT = "default"
DEVICE_TYPE_DEFAULT = "All"
DEVICE_NAME_DEFAULT = "default"
DEVICE_NAME_DEFAULT_CONSTANT = "kDeviceNameDefault "  # C++ identifier emitted instead of a quoted device name
DEVICE_ARCHITECTURE_DEFAULT = "default"

# List of attributes
DEVICE_TYPE_ATTRIBUTES = ["clblast_device_vendor", "clblast_device_type"]
DEVICE_ATTRIBUTES = ["clblast_device_name", "clblast_device_architecture",
                     "device_core_clock", "device_compute_units"]
KERNEL_ATTRIBUTES = ["precision", "kernel_family"]
ARGUMENT_ATTRIBUTES = ["arg_m", "arg_n", "arg_k", "arg_alpha", "arg_beta",
                       "arg_from", "arg_to", "arg_step",
                       "arg_channels", "arg_height", "arg_width", "arg_kernel_h", "arg_kernel_w",
                       "arg_num_kernels", "arg_batch_count"]
ATTRIBUTES = DEVICE_ATTRIBUTES + DEVICE_TYPE_ATTRIBUTES + KERNEL_ATTRIBUTES + ARGUMENT_ATTRIBUTES
GROUP_ATTRIBUTES = DEVICE_TYPE_ATTRIBUTES + KERNEL_ATTRIBUTES + ["kernel"] + ARGUMENT_ATTRIBUTES

# Other constants
# Vendors for which an empty 'clblast_device_architecture' is filtered out when generating C++
VENDORS_WITH_ARCHITECTURE = ["AMD", "NVIDIA"]
+
def precision_to_string(precision):
    """Translates a precision number (represented as Python string) into a descriptive string.

    Bug fix: raises ValueError for unknown precisions; the original `raise("...")` tried to
    raise a plain string, which itself fails with a TypeError in Python 3 and hides the message.
    """
    if precision == "16":
        return "Half"
    elif precision == "32":
        return "Single"
    elif precision == "64":
        return "Double"
    elif precision == "3232":
        return "ComplexSingle"
    elif precision == "6464":
        return "ComplexDouble"
    else:
        raise ValueError("Unknown precision: " + precision)
+
+
def get_cpp_separator():
    """Returns the horizontal comment rule used to delimit sections in generated C++ files"""
    separator = "// ================================================================================================="
    return separator
+
+
def get_cpp_header(family, precision):
    """Builds the auto-generation banner placed at the top of every generated C++ file"""
    separator = get_cpp_separator()
    banner = """
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. It
// is auto-generated by the 'scripts/database/database.py' Python script.
//
// This file populates the database with best-found tuning parameters for the '%s%s' kernels.
//\n""" % (family.title(), precision)
    return "\n" + separator + banner + separator + "\n"
+
+
def get_cpp_header_namespace():
    """Opens the nested clblast::database namespaces in generated C++ code"""
    return "\nnamespace clblast {\nnamespace database {\n"
+
+
def get_cpp_footer():
    """Closes the clblast::database namespaces opened by get_cpp_header_namespace"""
    return "\n} // namespace database\n} // namespace clblast\n"
+
+
def get_cpp_precision(family, precision):
    """Emits the opening of the C++ DatabaseEntry for a kernel family at a given precision"""
    precision_name = precision_to_string(precision)
    entry_name = family.title().replace("_", "")
    return ("\nconst DatabaseEntry %s%s = {\n \"%s\", Precision::k%s"
            % (entry_name, precision_name, entry_name, precision_name))
+
+
def get_cpp_device_vendor(vendor, device_type):
    """Emits the C++ code opening a vendor/device-type group (or the catch-all default group)"""
    if vendor == VENDOR_DEFAULT and device_type == DEVICE_TYPE_DEFAULT:
        return " { // Default\n kDeviceType%s, \"%s\", {\n" % (device_type, vendor)
    capitalized_type = device_type[0].upper() + device_type[1:]
    return " { // %s %ss\n kDeviceType%s, \"%s\", {\n" % (vendor, device_type, capitalized_type, vendor)
+
+
def get_cpp_family_includes(family, precisions):
    """Emits the #include lines of the combined .cpp source for one kernel family"""
    lines = ["\n"]
    lines.append("#include \"database/kernels/%s/%s.hpp\"\n" % (family, family))
    for precision in precisions:
        lines.append("#include \"database/kernels/%s/%s_%s.hpp\"\n" % (family, family, precision))
    return "".join(lines)
+
+
def get_hpp_family_includes(family, precisions):
    """Emits the umbrella header declaring the extern DatabaseEntry objects of one family"""
    entry_name = family.title().replace("_", "")
    declarations = "".join(
        "extern const DatabaseEntry %s%s;\n" % (entry_name, precision_to_string(precision))
        for precision in precisions)
    return ("\n"
            "#include \"database/database_structure.hpp\"\n"
            "\n"
            "namespace clblast {\n"
            "namespace database {\n"
            "\n"
            + declarations +
            "\n"
            "} // namespace database\n"
            "} // namespace clblast\n")
+
+
def print_as_name(name):
    """Wraps a device name in a C++ Name{...} literal, trimmed and capped at STRING_LENGTH"""
    truncated = name.strip()[:STRING_LENGTH]
    return "Name{\"%-50s\"}" % truncated
+
+
def get_kernel_database_results(kernel_database):
    """Retrieves the best result from a group of results. Asserts for valid data"""
    assert len(kernel_database) >= 1

    all_results = [entry["results"] for entry in kernel_database]
    best_results = all_results[0]
    for results in all_results:

        # Each database is expected to hold exactly one result, with identical parameter names
        length_ok = (len(results) == 1)
        params_ok = (sorted(results[0]["parameters"]) == sorted(best_results[0]["parameters"]))
        if not length_ok or not params_ok:
            # Debugging in case of unexpected results
            print("[database] ERROR: Found %d kernel databases, expected 1" % len(kernel_database))
            all_keys = sorted([key for item in kernel_database for key in item.keys()])
            missing_keys = set([x for x in all_keys if all_keys.count(x) != len(kernel_database)])
            print("[database] All keys in databases: %s" % str(set(all_keys)))
            print("[database] Missing keys in one or more databases: %s" % str(missing_keys))
            for index, item in enumerate(kernel_database):
                print("[database] %d:" % index)
                print(item)
        assert length_ok
        assert params_ok

        if results[0]["time"] < best_results[0]["time"]:
            best_results = results

    return best_results
+
+
def print_cpp_database(database, output_dir):
    """Outputs the database as C++ code: one .hpp per family/precision plus a combined .cpp/.hpp
    pair per kernel family, written under 'output_dir/<family>/'"""

    # Iterates over the kernel families
    kernel_families = sorted(set([s["kernel_family"] for s in database["sections"]]))
    for family_name in kernel_families:
        family_database = [s for s in database["sections"] if s["kernel_family"] == family_name]

        # Goes into a new path for each kernel family
        family_path = os.path.join(output_dir, family_name)

        # Loops over the different precision (e.g. 16, 32, 3232, 64, 6464)
        precisions = sorted(set([s["precision"] for s in database["sections"]]))  # Based on full database
        for precision in precisions:
            precision_database = [s for s in family_database if s["precision"] == precision]

            # Opens a new file for each precision
            full_path = os.path.join(family_path, family_name + "_" + precision + ".hpp")
            with open(full_path, 'w+') as f:
                f.write(get_cpp_header(family_name, precision))
                f.write(get_cpp_header_namespace())
                f.write(get_cpp_precision(family_name, precision))

                # In case there is nothing found at all (e.g. 16-bit): continue as if this was a
                # precision of 32 but with the defaults only
                if len(precision_database) == 0:
                    print("[database] No results found for %s:%s, retrieving defaults from %s:32" %
                          (family_name, precision, family_name))
                    precision_database = [s for s in family_database if s["precision"] == "32"
                                          and s["clblast_device_vendor"] == VENDOR_DEFAULT
                                          and s["clblast_device_type"] == DEVICE_TYPE_DEFAULT
                                          and s["clblast_device_name"] == DEVICE_NAME_DEFAULT]

                # Discovers the parameters for this kernel
                parameter_names = []
                for example_data in precision_database:
                    for example_result in example_data["results"]:
                        parameter_names.extend([str(k) for k in example_result["parameters"].keys()])
                parameter_names = sorted(list(set(parameter_names)))
                parameter_names_as_string = ", ".join(['"%s"' % p for p in parameter_names])
                f.write(", {" + parameter_names_as_string + "}, {\n")

                # Loops over device vendors (e.g. AMD)
                device_vendors = sorted(set([s["clblast_device_vendor"] for s in precision_database]))
                for vendor in device_vendors:
                    vendor_database = [s for s in precision_database if s["clblast_device_vendor"] == vendor]

                    # Loops over device types (e.g. GPU)
                    device_types = sorted(set([s["clblast_device_type"] for s in vendor_database]))
                    for device_type in device_types:
                        type_database = [s for s in vendor_database if s["clblast_device_type"] == device_type]
                        f.write(get_cpp_device_vendor(vendor, device_type))

                        # Loops over every architecture of this vendor-type combination
                        architectures = sorted(set([s["clblast_device_architecture"] for s in type_database]))
                        if vendor in VENDORS_WITH_ARCHITECTURE:
                            architectures = [a for a in architectures if a != ""]
                        for architecture in architectures:
                            architecture_database = [s for s in type_database if s["clblast_device_architecture"] == architecture]
                            architecture_string = DEVICE_ARCHITECTURE_DEFAULT if architecture == "" else architecture
                            f.write(" { \"%s\", {\n" % architecture_string)

                            # Loops over every device of this vendor-type combination
                            devices = sorted(set([s["clblast_device_name"] for s in architecture_database]))
                            for device_name in devices:
                                device_database = [s for s in architecture_database if s["clblast_device_name"] == device_name]
                                device_name_as_string = print_as_name(device_name) if device_name != DEVICE_NAME_DEFAULT else DEVICE_NAME_DEFAULT_CONSTANT
                                device_name_cpp = " { %s, Params{ " % device_name_as_string
                                f.write(device_name_cpp)

                                # Collects the parameters for this entry
                                parameters = []
                                parameter_index = 0
                                kernels = sorted(set([s["kernel"] for s in device_database]))
                                for kernel in kernels:
                                    kernel_database = [s for s in device_database if s["kernel"] == kernel]
                                    results = get_kernel_database_results(kernel_database)

                                    assert len(results) == 1
                                    new_parameters = results[0]["parameters"]
                                    for parameter_name in sorted(new_parameters):
                                        # Parameters are emitted positionally: names must line up
                                        # with the file-wide sorted 'parameter_names' list
                                        assert parameter_name == parameter_names[parameter_index]
                                        parameter_value = new_parameters[parameter_name]
                                        parameters.append(str(parameter_value))
                                        parameter_index += 1

                                # Appends zero's to complete the list
                                assert parameter_index <= PARAMETERS_LENGTH
                                for append_index in range(parameter_index, PARAMETERS_LENGTH):
                                    parameters.append("0")

                                # Prints the entry
                                f.write(", ".join(parameters))
                                f.write(" } },\n")

                            # Prints the architecture footer
                            f.write(" } },\n")

                        # Prints the vendor-type combination footer
                        f.write(" }\n },\n")

                # Prints the precision footer
                f.write(" }\n};\n")

                # Prints the file footer
                f.write(get_cpp_footer())

        # Creates the combined family sources
        full_path = os.path.join(family_path, family_name + ".cpp")
        with open(full_path, 'w+') as f:
            f.write(get_cpp_header(family_name, ""))
            f.write(get_cpp_family_includes(family_name, precisions))

        # Creates the combined family includes header
        full_path = os.path.join(family_path, family_name + ".hpp")
        with open(full_path, 'w+') as f:
            f.write(get_cpp_header(family_name, ""))
            f.write(get_hpp_family_includes(family_name, precisions))
diff --git a/scripts/database/database/db.py b/scripts/database/database/db.py
new file mode 100644
index 00000000..bbe5a247
--- /dev/null
+++ b/scripts/database/database/db.py
@@ -0,0 +1,76 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import itertools
+from operator import itemgetter
+
+
def length(database):
    """Computes the total number of tuning entries"""
    return sum(len(section["results"]) for section in database["sections"])
+
+
def add_section(database, new_section):
    """Adds a new section to the database, merging its results into an existing section when
    every non-'results' attribute matches"""
    new_attributes = {key: value for key, value in new_section.items() if key != "results"}
    for old_section in database["sections"]:

        # The sections match when every attribute is present and equal in the old section
        matches = all(
            attribute in old_section and old_section[attribute] == value
            for attribute, value in new_attributes.items()
        )

        # They match: merge the new results into the existing entry and stop
        if matches:
            old_section["results"] = combine_results(old_section["results"], new_section["results"])
            return database

    # No match found: append the whole new section to the database
    database["sections"].append(new_section)
    return database
+
+
def combine_results(old_results, new_results):
    """Merges each of the new results into the existing results JSON list"""
    for result in new_results:
        old_results = combine_result(old_results, result)
    return old_results
+
+
def combine_result(old_results, new_result):
    """Adds a new result to the results JSON list; filters for duplicate entries and saves the
    best performing one"""
    for old_result in old_results:
        # Same parameter combination: keep whichever run was faster
        if new_result["parameters"] == old_result["parameters"]:
            old_result["time"] = min(old_result["time"], new_result["time"])
            return old_results
    # Nothing matched: this is a genuinely new result
    old_results.append(new_result)
    return old_results
+
+
def group_by(database, attributes):
    """Returns a list of (group name, entries) tuples, grouping the database entries by the
    given attributes (attributes missing from the data are ignored). Sorts the list in place."""
    assert len(database) > 0
    present = [attribute for attribute in attributes if attribute in database[0]]
    key_of = itemgetter(*present)
    database.sort(key=key_of)
    return [(name, list(entries)) for name, entries in itertools.groupby(database, key=key_of)]
diff --git a/scripts/database/database/defaults.py b/scripts/database/database/defaults.py
new file mode 100644
index 00000000..a7a98d23
--- /dev/null
+++ b/scripts/database/database/defaults.py
@@ -0,0 +1,240 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import ast
+from collections import defaultdict
+
+import database.bests as bests
+import database.clblast as clblast
+
+
def set_identifiers(database, group_by_attributes, identifier_name):
    """Sets a group-identifier based on a given set of attributes. Modifies the database but also
    returns a list of unique identifiers."""
    identifiers = []
    for section in database["sections"]:
        # Attributes absent from a section are simply left out of its identifier
        parts = [section[attribute] for attribute in group_by_attributes if attribute in section]
        section[identifier_name] = ";".join(parts)
        identifiers.append(section[identifier_name])
    return sorted(set(identifiers))
+
+
def remove_identifiers(database, identifier_name):
    """Removes an identifier from all sections in the database (tolerates absent keys)"""
    for section in database["sections"]:
        if identifier_name in section:
            del section[identifier_name]
+
+
def get_groups_by_identifier(database, group_identifiers, identifier_name):
    """Returns a list of (group, group_identifier) tuples based on a previously made grouping"""
    groups = []
    for group_identifier in group_identifiers:
        # Collect every section carrying this identifier value
        members = [section for section in database["sections"]
                   if section[identifier_name] == group_identifier]
        groups.append((members, group_identifier))
    return groups
+
+
def add_default_sections(database, grouping, verbose, values_dict, condition, enable_warning):
    """Builds one default section per group of 'database' (grouped by the 'grouping' attributes),
    carrying the group's common best parameters. Only groups whose first section satisfies
    'condition' produce a default; 'values_dict' entries overwrite the copied attributes.
    Note: modifies the database in place by adding (temporary) group identifiers."""
    default_sections = []

    # Groups the database by a certain grouping
    group_identifiers = set_identifiers(database, grouping, "group_identifier")
    groups = get_groups_by_identifier(database, group_identifiers, "group_identifier")

    # Loops over all groups
    for group, group_identifier in groups:

        # Computes the best parameters
        default_parameters = get_common_best_parameters(group, group_identifier, verbose, enable_warning)
        assert len(group) > 0
        if condition(group[0]):

            # Stores all the section's data
            default_section = {}
            for attribute in group[0].keys():
                if attribute != "results" and attribute != "group_identifier":
                    default_section[attribute] = group[0][attribute]
            # Device-specific metadata is zeroed out for a default entry
            default_section["clblast_device_compute_units"] = 0
            default_section["clblast_device_core_clock"] = 0
            for key in values_dict.keys():
                default_section[key] = values_dict[key]
            default_section["results"] = [{"time": 0.0, "parameters": default_parameters}]
            default_sections.append(default_section)
    return default_sections
+
+
def calculate_defaults(database, verbose):
    """Sets defaults for devices of the same type/vendor"""
    default_sections = {"sections": []}

    # Groups the database by kernel, vendor and device architecture (e.g. AMD GPU "Fiji")
    architecture_group = clblast.GROUP_ATTRIBUTES + ["clblast_device_architecture"]
    architecture_defaults = add_default_sections(database, architecture_group, verbose,
                                                 {"clblast_device_name": clblast.DEVICE_NAME_DEFAULT},
                                                 lambda entry: True, enable_warning=False)

    # Groups the database by kernel, vendor and device type (e.g. AMD GPU)
    device_defaults = add_default_sections(database, clblast.GROUP_ATTRIBUTES, verbose,
                                           {"clblast_device_name": clblast.DEVICE_NAME_DEFAULT,
                                            "clblast_device_architecture": clblast.DEVICE_ARCHITECTURE_DEFAULT},
                                           lambda entry: entry["clblast_device_architecture"] != "",
                                           enable_warning=True)
    default_sections["sections"].extend(device_defaults)

    # Groups the database by kernel, vendor and device type (e.g. AMD GPU) - but not by arguments!
    # This is to check for mis-matched arguments in the database. Note: this is not a check on the
    # architecture defaults
    attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
    group_identifiers = set_identifiers(default_sections, attributes, "temp_identifier")
    groups = get_groups_by_identifier(default_sections, group_identifiers, "temp_identifier")
    for group, group_identifier in groups:
        if len(group) != 1:
            print("[ERROR] Entries for a single kernel with multiple argument values: " + str(group_identifier))
        assert len(group) == 1
    remove_identifiers(default_sections, "temp_identifier")

    # Adds the architecture defaults only after running the above check
    default_sections["sections"].extend(architecture_defaults)

    # Groups the database by kernel only
    group_identifiers = set_identifiers(database, clblast.KERNEL_ATTRIBUTES + ["kernel"], "group_identifier")
    groups = get_groups_by_identifier(database, group_identifiers, "group_identifier")

    # Loops over all groups
    for group, group_identifier in groups:

        # Computes the best parameters
        default_parameters = get_common_best_parameters(group, group_identifier, verbose,
                                                        enable_warning=True)

        # Stores all the section's data
        assert len(group) > 0
        default_section = {}
        for attribute in group[0].keys():
            if attribute != "results" and attribute != "group_identifier":
                default_section[attribute] = group[0][attribute]
        # The vendor-wide fallback entry: all device identification reset to the defaults
        default_section["clblast_device_name"] = clblast.DEVICE_NAME_DEFAULT
        default_section["clblast_device_architecture"] = clblast.DEVICE_ARCHITECTURE_DEFAULT
        default_section["clblast_device_vendor"] = clblast.VENDOR_DEFAULT
        default_section["clblast_device_type"] = clblast.DEVICE_TYPE_DEFAULT
        default_section["clblast_device_compute_units"] = 0
        default_section["clblast_device_core_clock"] = 0
        default_section["results"] = [{"time": 0.0, "parameters": default_parameters}]
        default_sections["sections"].append(default_section)

    # Database with both types of defaults only
    return default_sections
+
+
def get_smallest_best_parameters(group):
    """Sets defaults based on the smallest values of all known entries. The average might be better for performance but
    some parameters might not be supported on other devices."""
    assert len(group) > 0

    # For each section keep only the results tied for the lowest time, then take the
    # element-wise minimum of their parameter values across the entire group
    smallest = {}
    for entry in group:
        assert len(entry["results"]) > 0
        best_time = min(result["time"] for result in entry["results"])
        for candidate in (r for r in entry["results"] if r["time"] == best_time):
            for name, value in candidate["parameters"].items():
                if name in smallest:
                    smallest[name] = min(smallest[name], value)
                else:
                    smallest[name] = value
    return smallest
+
+
def get_parameter_names(section):
    """Returns the parameter dictionaries of a section, one per result."""
    names = []
    for entry in section["results"]:
        names.append(entry["parameters"])
    return names
+
+
def get_common_best_parameters(group, group_identifier, verbose, enable_warning):
    """Sets defaults based on the best values of entries supported by all devices. This might cause a problem in case
    not every device was tuned with the same parameters. In that case it falls back to the above method to retrieve
    the smallest best execution time"""

    # Counts the number of devices in this group
    num_devices = len(group)
    assert num_devices > 0

    # Inserts the relative execution times into the database
    # NOTE: mutates each section in 'group' in place by adding a "relative_time" field to every result
    for section in group:
        assert len(section["results"]) > 0
        minimum_time = min([result["time"] for result in section["results"]])
        for result in section["results"]:
            # For the "gemm_kernel_selection" pseudo-kernel the stored time is used as-is (baseline 1.0)
            base_line = minimum_time if section["kernel"] != "gemm_kernel_selection" else 1.0
            result["relative_time"] = result["time"] / base_line

    # Determine which parameters are available for all devices
    common_parameters = get_parameter_names(group[0])  # Parameters of the first section
    for i in range(1, num_devices):
        section_parameters = get_parameter_names(group[i])
        common_parameters = [p for p in section_parameters if p in common_parameters]  # Intersection of the parameters

    # Fall back to another method in case there are no shared entries at all across devices
    if len(common_parameters) == 0:
        if verbose:
            print("[database] No common kernels for: " + str(group_identifier) + " across all %d devices " % num_devices)

        # Computes the amount of devices with shared parameters
        # Parameter dicts are keyed by their str() form so they can be counted; they are
        # converted back below with ast.literal_eval
        parameters_count = defaultdict(int)
        for i in range(0, num_devices):
            for parameters in get_parameter_names(group[i]):
                parameters_count[str(parameters)] += 1
        num_devices_common = max(parameters_count.values())

        # Fall back method in case there are no shared entries at all across devices
        if num_devices_common == 1:
            if enable_warning:
                print("[database] Warning: No common kernels for: " + str(group_identifier) + " at all")
            smallest_best_parameters = get_smallest_best_parameters(group)
            if verbose:
                print("[database] " + str(group_identifier))
            return smallest_best_parameters

        # Checks if perhaps there are many more shared parameters with a bit fewer devices
        num_parameters_common = defaultdict(int)
        for count in parameters_count.values():
            if count != 1:
                num_parameters_common[str(count)] += 1
        # Relax the device count by one if that yields strictly more shared parameter sets
        if num_parameters_common[str(num_devices_common - 1)] > num_parameters_common[str(num_devices_common)]:
            num_devices_common -= 1
        if verbose:
            print("[database] Found %d common kernels for: " % num_parameters_common[str(num_devices_common)] +
                  str(group_identifier) + " across %d out of %d devices " % (num_devices_common, num_devices))

        # Populates the common parameters
        for parameters_string in parameters_count.keys():
            count = parameters_count[parameters_string]
            if count == num_devices_common:
                parameters = ast.literal_eval(parameters_string)
                common_parameters.append(parameters)

    # Removes entries with parameters which are not common
    common_results = []
    for section in group:
        for result in section["results"]:
            if result["parameters"] in common_parameters:
                common_results.append(result)

    # Retrieves the entries with the highest relative performance
    relative_best_parameters = bests.get_relative_bests(group_identifier, common_results, common_parameters, verbose)
    return relative_best_parameters
diff --git a/scripts/database/database/io.py b/scripts/database/database/io.py
new file mode 100644
index 00000000..b66f18b1
--- /dev/null
+++ b/scripts/database/database/io.py
@@ -0,0 +1,113 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import re
+import json
+
+try:
+ from urllib.request import urlopen # Python 3
+except ImportError:
+ from urllib2 import urlopen # Python 2
+
+
def download_database(filename, database_url):
    """Downloads a database from 'database_url' and saves it to disk as 'filename'.

    Fix: the original never closed the HTTP response returned by urlopen, leaking the
    underlying socket. contextlib.closing works for both the Python 3 and the Python 2
    (urllib2) response objects, which are not both context managers.
    """
    from contextlib import closing  # local import keeps this fix self-contained
    print("[database] Downloading database from '" + database_url + "'...")
    with closing(urlopen(database_url)) as database:
        with open(filename, "wb") as f:
            f.write(database.read())
+
+
def load_database(filename):
    """Reads a compressed database from disk and returns it in decompressed form"""
    print("[database] Loading database from '" + filename + "'")
    with open(filename) as db_file:
        contents = json.load(db_file)
    return decompress_database(contents)
+
+
def save_database(database, filename):
    """Compresses a database and writes it to disk as JSON"""
    to_store = compress_database(database)
    print("[database] Saving database to '" + filename + "'")
    with open(filename, "w") as db_file:
        json.dump(to_store, db_file, sort_keys=True, indent=2, separators=(',', ': '))
+
+
def compress_database(database):
    """Moves certain common fields up in the hierarchy, transforms dicts into lists"""
    compressed_sections = []
    for section in database["sections"]:
        compressed = {}
        for key, value in section.items():
            if key == "results":
                # Every result in a section must carry the exact same parameter names
                names_per_result = [sorted(r["parameters"].keys()) for r in section["results"]]
                assert len({" ".join(names) for names in names_per_result}) == 1
                compressed["parameter_names"] = names_per_result[0]  # they are all the same
                # Each result becomes a [comma-joined-values, time] pair
                compressed[key] = [
                    [",".join(str(r["parameters"][name]) for name in compressed["parameter_names"]), r["time"]]
                    for r in section["results"]
                ]
            elif key != "parameter_names":
                compressed[key] = value
        compressed_sections.append(compressed)
    return {"sections": compressed_sections}
+
+
def decompress_database(database):
    """Undo the above compression (in place; also returns the database)"""
    for section in database["sections"]:
        expanded = []
        for compressed_result in section["results"]:
            # Pair the stored comma-joined values back up with the section's parameter names
            values = compressed_result[0].split(",")
            parameters = {name: int(value) for name, value in zip(section["parameter_names"], values)}
            expanded.append({"parameters": parameters, "time": compressed_result[1]})
        section["results"] = expanded
    return database
+
+
def load_tuning_results(filename):
    """Loads JSON tuning data from file and pre-processes it.

    Returns the parsed dict with: the numbering stripped from "kernel_family", the
    redundant "best_*" fields removed, the kernel name hoisted from the individual
    results to the section, the redundant "PRECISION" parameter dropped, and known
    mis-formatted scalar argument values normalised.
    """
    with open(filename) as f:
        json_data = json.load(f)

    # Removes the numbering following the kernel family name (e.g. "xgemm_1" -> "xgemm")
    json_data["kernel_family"] = re.sub(r'_\d+', '', json_data["kernel_family"])

    # Removes unnecessary data. Fix: the original indexed these keys directly, which
    # raised KeyError when a key was absent and skipped the deletion for falsy values
    # (e.g. a best time of 0.0); pop() removes the key whenever it is present.
    for redundant_key in ("best_kernel", "best_time", "best_parameters"):
        json_data.pop(redundant_key, None)

    # Adds the kernel name to the section instead of to the individual results
    assert len(json_data["results"]) > 0
    json_data["kernel"] = json_data["results"][0]["kernel"]
    for result in json_data["results"]:
        assert json_data["kernel"] == result["kernel"]
        result.pop("kernel", None)

    # Removes the 'PRECISION' parameter from the individual results: it is redundant
    for result in json_data["results"]:
        assert json_data["precision"] == str(result["parameters"]["PRECISION"])
        result["parameters"].pop("PRECISION", None)

    # Fixes the scalar argument values
    for value, replacement in zip(["2.00", "2.00+0.50i"], ["2.000000", "2+0.5i"]):
        for field in ["arg_alpha", "arg_beta"]:
            if field in json_data.keys() and json_data[field] == value:
                json_data[field] = replacement

    # All done
    return json_data
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
new file mode 100755
index 00000000..76c5dc1c
--- /dev/null
+++ b/scripts/generator/generator.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+#
+# This script automatically generates the bodies of the following files, creating the full CLBlast API interface and
+# implementation (C, C++, and reference BLAS wrappers):
+# clblast.h
+# clblast.cpp
+# clblast_c.h
+# clblast_c.cpp
+# clblast_cuda.h
+# clblast_cuda.cpp
+# clblast_netlib_c.h
+# clblast_netlib_c.cpp
+# wrapper_clblas.h
+# wrapper_cblas.h
+# pyclblast.pyx
+# It also generates the main functions for the correctness and performance tests as found in
+# test/correctness/routines/levelX/xYYYY.cpp
+# test/performance/routines/levelX/xYYYY.cpp
+# It also produces the API documentation found in doc/clblast.md
+
+
+import sys
+import os.path
+import argparse
+
+import generator.cpp as cpp
+import generator.doc as doc
+import generator.pyclblast as pyclblast
+from generator.routine import Routine
+from generator.datatype import H, S, D, C, Z, Sc, Dz, iH, iS, iD, iC, iZ, Css, Zdd, Ccs, Zzd, T, Tc, TU
+
# Files whose bodies get re-generated, relative to the CLBlast root passed on the command line.
# HEADER_LINES and FOOTER_LINES correspond 1:1 with FILES (12 entries each): the number of lines
# at the top and bottom of each file that are preserved as-is when the body is re-written.
FILES = [
    "/include/clblast.h",
    "/src/clblast.cpp",
    "/include/clblast_c.h",
    "/src/clblast_c.cpp",
    "/test/wrapper_clblas.hpp",
    "/test/wrapper_cblas.hpp",
    "/test/wrapper_cublas.hpp",
    "/include/clblast_netlib_c.h",
    "/src/clblast_netlib_c.cpp",
    "/include/clblast_cuda.h",
    "/src/clblast_cuda.cpp",
    "/src/pyclblast/src/pyclblast.pyx"
]
HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 327]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
# Preserved header/footer line counts for the generated documentation (doc/clblast.md, presumably)
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232
+
# Different possibilities for requirements: human-readable sentences attached to routines
# (see the final argument of each Routine below). Naming scheme: <buffer>ld_<bound>, e.g.
# ald_m = requirement on a_ld in terms of m.
ald_m = "The value of `a_ld` must be at least `m`."
ald_n = "The value of `a_ld` must be at least `n`."
ald_k_one = "The value of `a_ld` must be at least `k + 1`."
ald_kl_ku_one = "The value of `a_ld` must be at least `kl + ku + 1`."
ald_transa_m_k = "When `transpose_a == Transpose::kNo`, then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `k`."
ald_trans_n_k = "When `transpose == Transpose::kNo`, then `a_ld` must be at least `n`, otherwise `a_ld` must be at least `k`."
ald_side_m_n = "When `side = Side::kLeft` then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `n`."
bld_m = "The value of `b_ld` must be at least `m`."
bld_n = "The value of `b_ld` must be at least `n`."
bld_transb_k_n = "When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`."
bld_trans_n_k = "When `transpose == Transpose::kNo`, then `b_ld` must be at least `n`, otherwise `b_ld` must be at least `k`."
cld_m = "The value of `c_ld` must be at least `m`."
cld_n = "The value of `c_ld` must be at least `n`."
+
# Helper functions to compute vector and matrix sizes
def size_helper(condition, size_one, size_two, multiplier):
    """Builds a C ternary expression choosing between two sizes, each scaled by 'multiplier'."""
    scaled_one = size_one + " * " + multiplier
    scaled_two = size_two + " * " + multiplier
    return "(" + condition + ") ? " + scaled_one + " : " + scaled_two
+
+
def layout_transpose_condition(prefix):
    """Builds the C condition under which the '<prefix>' matrix is effectively transposed in memory."""
    transpose = prefix + "_transpose"
    col_major = "(layout == CLBlastLayoutColMajor && " + transpose + " != CLBlastTransposeNo)"
    row_major = "(layout == CLBlastLayoutRowMajor && " + transpose + " == CLBlastTransposeNo)"
    return col_major + " || " + row_major
+
+
# Different possibilities for the vector and matrix sizes: each value is a C expression
# (as a string) computing the required buffer size for a routine argument. Naming scheme:
# first letter(s) = buffer (x/y/z vectors, a/b/c matrices, ap packed), following letters =
# the size variables involved (and trailing 's'/'a'/'ab' hint at the side/transpose option used).
xn = "n * x_inc"
xm = "m * x_inc"
yn = "n * y_inc"
ym = "m * y_inc"
zn = "n * z_inc"
an = "n * a_ld"
apn = "((n*(n+1)) / 2)"  # packed triangular storage
cn = "n * c_ld"
xmn = size_helper("a_transpose != CLBlastTransposeNo", "m", "n", "x_inc")
ynm = size_helper("a_transpose != CLBlastTransposeNo", "n", "m", "y_inc")
amn = size_helper("layout == CLBlastLayoutRowMajor", "m", "n", "a_ld")
amns = size_helper("side == CLBlastSideLeft", "m", "n", "a_ld")
amk = size_helper(layout_transpose_condition("a"), "m", "k", "a_ld")
ank = size_helper(layout_transpose_condition("a"), "n", "k", "a_ld")
ankab = size_helper(layout_transpose_condition("ab"), "n", "k", "a_ld")
bkn = size_helper(layout_transpose_condition("b"), "k", "n", "b_ld")
bnkab = size_helper(layout_transpose_condition("ab"), "n", "k", "b_ld")
bmn = size_helper("layout == CLBlastLayoutRowMajor", "m", "n", "b_ld")
bnma = size_helper(layout_transpose_condition("a"), "n", "m", "b_ld")
cmn = size_helper("layout == CLBlastLayoutRowMajor", "m", "n", "c_ld")
ammn = size_helper("layout == CLBlastLayoutRowMajor", "m", "((side == CLBlastSideLeft) ? m : n)", "a_ld")
bmnn = size_helper("layout == CLBlastLayoutRowMajor", "((side == CLBlastSideLeft) ? m : n)", "n", "b_ld")
# Sizes for the im2col/col2im/convgemm routines ('im' and 'col' are intentionally identical)
im = "height * width * channels"
col = "height * width * channels"
imb = "height * width * channels * batch_count"
kernel = "kernel_h * kernel_w * num_kernels"
result = "height_out * width_out * num_kernels * batch_count"
+
+
+# ==================================================================================================
+
+# Populates a list of routines
+im2col_constants = ["channels", "height", "width", "kernel_h", "kernel_w", "pad_h", "pad_w", "stride_h", "stride_w", "dilation_h", "dilation_w"]
+convgemm_constants = im2col_constants + ["num_kernels", "batch_count"]
+ROUTINES = [
+[ # Level 1: vector-vector
+ Routine(False, True, 0, False, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], ["1","1","1","1"], [], "", "Generate givens plane rotation", "", []),
+ Routine(False, True, 0, False, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], ["1","1","1","1","1"], [], "", "Generate modified givens plane rotation", "", []),
+ Routine(False, True, 0, False, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], [xn,yn], ["cos","sin"],"", "Apply givens plane rotation", "", []),
+ Routine(False, True, 0, False, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [xn,yn,"1"], [], "", "Apply modified givens plane rotation", "", []),
+ Routine(True, True, 0, False, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [xn,yn], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []),
+ Routine(True, True, 0, False, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], [xn], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []),
+ Routine(True, True, 0, False, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []),
+ Routine(True, True, 0, False, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []),
+ Routine(True, True, 0, False, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []),
+ Routine(True, True, 0, False, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
+ Routine(True, True, 0, False, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
+ Routine(True, True, 0, False, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [xn,"1"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []),
+ Routine(True, True, 0, False, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [xn,"1"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []),
+ Routine(True, False, 0, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [xn,"1"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []),
+ Routine(True, True, 0, False, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []),
+ Routine(True, False, 0, False, "1", "amin", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of absolute minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer.", []),
+ Routine(True, False, 0, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []),
+ Routine(True, False, 0, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []),
+],
+[ # Level 2: matrix-vector
+ Routine(True, True, 0, False, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]),
+ Routine(True, True, 0, False, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]),
+ Routine(True, True, 0, False, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]),
+ Routine(True, True, 0, False, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]),
+ Routine(True, True, 0, False, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, 0, False, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]),
+ Routine(True, True, 0, False, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]),
+ Routine(True, True, 0, False, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, 0, False, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]),
+ Routine(True, True, 0, False, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]),
+ Routine(True, True, 0, False, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []),
+ Routine(True, True, 0, False, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a triangular system of equations", "", []),
+ Routine(False, True, 0, False, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]),
+ Routine(False, True, 0, False, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "", "Solves a packed triangular system of equations", "", []),
+ # Level 2: matrix update
+ Routine(True, True, 0, False, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]),
+ Routine(True, True, 0, False, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]),
+ Routine(True, True, 0, False, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]),
+ Routine(True, True, 0, False, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]),
+ Routine(True, True, 0, False, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, 0, False, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]),
+ Routine(True, True, 0, False, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, 0, False, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]),
+ Routine(True, True, 0, False, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
+ Routine(True, True, 0, False, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]),
+ Routine(True, True, 0, False, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
+],
+[ # Level 3: matrix-matrix
+ Routine(True, True, 0, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
+ Routine(True, True, 0, False, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]),
+ Routine(True, True, 0, False, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]),
+ Routine(True, True, 0, False, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]),
+ Routine(True, True, 0, False, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]),
+ Routine(True, True, 0, False, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
+ Routine(True, True, 0, False, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
+ Routine(True, True, 0, False, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
+ Routine(True, True, 0, False, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
+],
+[ # Level X: extra routines (not part of BLAS)
+ # Special routines:
+ Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []),
+ Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
+ Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, ["kernel_mode"], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix. Overwrites any existing values in the _col_ buffer", []),
+ Routine(True, True, 0, False, "x", "col2im", T, [S,D,C,Z,H], im2col_constants, ["kernel_mode"], ["col"], ["im"], [col,im], [""], "", "Col2im function (non-BLAS function)", "Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix. Accumulates results on top of the existing values in the _im_ buffer.", []),
+ Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, ["kernel_mode"], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
+ # Batched routines:
+ Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
+ Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
+ Routine(True, True, 2, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "StridedBatched version of GEMM", "As GEMM, but multiple strided operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
+]]
+
+
def main(argv):
    """Re-generates the CLBlast API sources, tests, and documentation.

    Validates that ``clblast_root`` points at the CLBlast source tree, then:
    1. re-writes the generated body of every file in FILES, preserving each
       file's original header and footer lines,
    2. writes the per-routine correctness and performance test sources,
    3. re-generates the API documentation in doc/api.md.

    :param argv: command-line arguments (without the program name)
    """

    # Parses the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("clblast_root", help="Root of the CLBlast sources")
    parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
    cl_args = parser.parse_args(argv)
    library_root = cl_args.clblast_root

    # Checks whether the command-line arguments are valid; exits otherwise
    for file_name in FILES:
        if not os.path.isfile(library_root + file_name):
            print("[ERROR] The path '" + library_root + "' does not point to the root of the CLBlast library")
            sys.exit(1)  # non-zero status: a failed sanity check is an error, not a normal exit

    # Iterates over all regular files to output
    for i, file_name in enumerate(FILES):

        # Stores the header and the footer of the original file
        with open(library_root + file_name) as f:
            original = f.readlines()
            file_header = original[:HEADER_LINES[i]]
            file_footer = original[-FOOTER_LINES[i]:]

        # Re-writes the body of the file
        with open(library_root + file_name, "w") as f:
            body = ""
            # The clBLAS/CBLAS/cuBLAS wrappers (indices 4-6) have no level-X routines
            levels = [1, 2, 3] if i in (4, 5, 6) else [1, 2, 3, 4]
            for level in levels:
                if i not in [11]:  # no level separators in the PyCLBlast output (index 11)
                    body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
                for routine in ROUTINES[level - 1]:
                    if i == 0:
                        body += cpp.clblast_h(routine)
                    if i == 1:
                        body += cpp.clblast_cc(routine)
                    if i == 2:
                        body += cpp.clblast_c_h(routine)
                    if i == 3:
                        body += cpp.clblast_c_cc(routine)
                    if i == 4:
                        body += cpp.wrapper_clblas(routine)
                    if i == 5:
                        body += cpp.wrapper_cblas(routine)
                    if i == 6:
                        body += cpp.wrapper_cublas(routine)
                    if i == 7:
                        # Netlib CBLAS has no batched or convgemm entry points
                        if routine.batched == 0 and routine.name not in ["convgemm"]:
                            body += cpp.clblast_netlib_c_h(routine)
                    if i == 8:
                        if routine.batched == 0 and routine.name not in ["convgemm"]:
                            body += cpp.clblast_netlib_c_cc(routine)
                    if i == 9:
                        body += cpp.clblast_h(routine, cuda=True)
                    if i == 10:
                        body += cpp.clblast_cc(routine, cuda=True)
                    if i == 11:
                        body += pyclblast.generate_pyx(routine)
            f.write("".join(file_header))
            f.write(body)
            f.write("".join(file_footer))

    # Outputs all the test implementations
    for level in [1, 2, 3, 4]:
        for routine in ROUTINES[level - 1]:
            if routine.has_tests:
                level_string = cpp.LEVEL_NAMES[level - 1]
                routine_suffix = "level" + level_string + "/x" + routine.lowercase_name() + ".cpp"

                # Correctness tests
                filename = library_root + "/test/correctness/routines/" + routine_suffix
                with open(filename, "w") as f:
                    f.write(cpp.HEADER + "\n")
                    f.write(cpp.correctness_test(routine, level_string))
                    f.write(cpp.FOOTER)

                # Performance tests
                filename = library_root + "/test/performance/routines/" + routine_suffix
                with open(filename, "w") as f:
                    f.write(cpp.HEADER + "\n")
                    f.write(cpp.performance_test(routine, level_string))
                    f.write(cpp.FOOTER)

    # API documentation
    filename = cl_args.clblast_root + "/doc/api.md"

    # Stores the header and the footer of the original documentation file
    with open(filename) as f:
        original = f.readlines()
        file_header = original[:HEADER_LINES_DOC]
        file_footer = original[-FOOTER_LINES_DOC:]

    # Outputs the API documentation
    with open(filename, "w") as f:

        # Outputs the header
        f.write("".join(file_header))
        doc_header = doc.header()
        f.write(doc_header)

        # Generates the documentation for each routine
        for level in [1, 2, 3, 4]:
            for routine in ROUTINES[level - 1]:
                if routine.implemented:
                    doc_routine = doc.generate(routine)
                    f.write(doc_routine)

        # Outputs the footer
        f.write("".join(file_footer))


if __name__ == '__main__':
    main(sys.argv[1:])
diff --git a/scripts/generator/generator/__init__.py b/scripts/generator/generator/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/scripts/generator/generator/__init__.py
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
new file mode 100644
index 00000000..16890d27
--- /dev/null
+++ b/scripts/generator/generator/convert.py
@@ -0,0 +1,84 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+
def precision_to_full_name(x):
    """Maps a single-letter precision code onto the full CLBlast precision name."""
    full_names = {
        'S': "Single",
        'D': "Double",
        'C': "ComplexSingle",
        'Z': "ComplexDouble",
        'H': "Half",
    }
    return full_names[x]
+
+
def option_to_clblast(x):
    """Maps an option name onto the corresponding CLBlast enum type name."""
    # All three transpose options share the same CLBlast type
    if x in ("a_transpose", "b_transpose", "ab_transpose"):
        return "Transpose"
    return {
        'layout': "Layout",
        'side': "Side",
        'triangle': "Triangle",
        'diagonal': "Diagonal",
        'kernel_mode': "KernelMode",
    }[x]
+
+
def option_to_clblas(x):
    """Maps an option name onto the corresponding clBLAS enum type name."""
    mapping = {
        'layout': "clblasOrder",
        'side': "clblasSide",
        'triangle': "clblasUplo",
        'diagonal': "clblasDiag",
    }
    for transpose_option in ("a_transpose", "b_transpose", "ab_transpose"):
        mapping[transpose_option] = "clblasTranspose"
    return mapping[x]
+
+
def option_to_cblas(x):
    """Maps an option name onto the corresponding Netlib CBLAS enum type name."""
    suffixes = {
        'layout': "ORDER",
        'a_transpose': "TRANSPOSE",
        'b_transpose': "TRANSPOSE",
        'ab_transpose': "TRANSPOSE",
        'side': "SIDE",
        'triangle': "UPLO",
        'diagonal': "DIAG",
    }
    return "CBLAS_" + suffixes[x]
+
+
def option_to_cublas(x):
    """Maps an option name onto the corresponding cuBLAS enum type name.

    Note that 'layout' maps onto the plain "Layout" name here, not onto a
    cublas* type.
    """
    if x == "layout":
        return "Layout"
    return {
        'a_transpose': "cublasOperation_t",
        'b_transpose': "cublasOperation_t",
        'ab_transpose': "cublasOperation_t",
        'side': "cublasSideMode_t",
        'triangle': "cublasFillMode_t",
        'diagonal': "cublasDiagType_t",
    }[x]
+
+
def option_to_documentation(x):
    """Maps an option name onto its user-facing documentation string."""
    docs = {
        'layout': "Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.",
        'a_transpose': "Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
        'b_transpose': "Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
        'ab_transpose': "Transposing the packed input matrix AP, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
        'side': "The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).",
        'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
        'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
        'kernel_mode': "The kernel mode, either `KernelMode::kCrossCorrelation` for the normal mode, or `KernelMode::kConvolution` for the convolution mode that flips a kernel along `h` and `w` axes.",
    }
    return docs[x]
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
new file mode 100644
index 00000000..6dc3fc93
--- /dev/null
+++ b/scripts/generator/generator/cpp.py
@@ -0,0 +1,422 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import generator.datatype as datatype
+import generator.convert as convert
+
+
NL = "\n"
SEPARATOR = "// ================================================================================================="

# Separators for the four BLAS levels
LEVEL_SEPARATORS = [
    NL + SEPARATOR + NL + "// " + title + NL + SEPARATOR
    for title in [
        "BLAS level-1 (vector-vector) routines",
        "BLAS level-2 (matrix-vector) routines",
        "BLAS level-3 (matrix-matrix) routines",
        "Extra non-BLAS routines (level-X)",
    ]
]

# Names of the level sub-folders
LEVEL_NAMES = ["1", "2", "3", "x"]

# Main header/footer for source files
FOOTER = NL + SEPARATOR + NL
HEADER = NL + SEPARATOR + """
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
""" + SEPARATOR + NL


def clblast_h(routine, cuda=False):
    """The C++ API header (.h): a comment line followed by the routine declaration."""
    comment = NL + "// " + routine.description + ": " + routine.short_names() + NL
    declaration = routine.routine_header_cpp(12, " = nullptr", cuda) + ";" + NL
    return comment + declaration
+
+
def clblast_cc(routine, cuda=False):
    """The C++ API implementation (.cpp).

    Emits the definition of the templated C++ entry point for the given
    routine, followed by one explicit template instantiation per precision
    flavour. With 'cuda' set, the CUDA API variant is generated instead of
    the OpenCL one. Unimplemented routines get a stub that returns
    StatusCode::kNotImplemented.
    """
    indent1 = " " * (15 + routine.length())  # aligns continuation lines of the Do...() call
    result = NL + "// " + routine.description + ": " + routine.short_names() + NL
    if routine.implemented:
        result += routine.routine_header_cpp(12, "", cuda, implementation=True) + " {" + NL
        result += " try {" + NL
        if cuda:
            # CUDA variant: context and device are explicit parameters
            result += " const auto context_cpp = Context(context);" + NL
            result += " const auto device_cpp = Device(device);" + NL
            result += " auto queue_cpp = Queue(context_cpp, device_cpp);" + NL
        else:
            result += " auto queue_cpp = Queue(*queue);" + NL
        event = "nullptr" if cuda else "event"  # the CUDA API has no OpenCL event to forward
        result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
        if routine.batched == 1:
            result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
        if routine.temp_buffer:
            # Optional user-supplied temporary buffer; a null handle means "not provided"
            null = "0" if cuda else "nullptr"
            result += " const auto temp_buffer_provided = temp_buffer != " + null + ";\n"
            result += " auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(" + null + ");\n"
        result += " routine.Do" + routine.capitalized_name() + "("
        result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
        if routine.temp_buffer:
            result += ",\n" + indent1 + "temp_buffer_cpp, temp_buffer_provided"
        result += ");" + NL
        result += " return StatusCode::kSuccess;" + NL
        result += " } catch (...) { return DispatchException(); }" + NL
    else:
        # Not-implemented stub with the same signature
        result += routine.routine_header_type_cpp(12, cuda) + " {" + NL
        result += " return StatusCode::kNotImplemented;" + NL
    result += "}" + NL
    for flavour in routine.flavours:
        # Explicit template instantiation for this precision flavour
        indent2 = " " * (34 + routine.length() + len(flavour.template))
        result += "template StatusCode PUBLIC_API " + routine.capitalized_name() + "<" + flavour.template + ">("
        arguments = routine.arguments_type(flavour)
        if cuda:
            arguments = [a.replace("cl_mem", "CUdeviceptr") for a in arguments]
        result += ("," + NL + indent2).join([a for a in arguments])
        result += "," + NL + indent2
        if cuda:
            result += "const CUcontext, const CUdevice"
            if routine.temp_buffer:
                result += ", CUdeviceptr"
        else:
            result += "cl_command_queue*, cl_event*"
            if routine.temp_buffer:
                result += ", cl_mem"
        result += ");" + NL
    return result
+
+
def clblast_c_h(routine):
    """The C API header (.h): one PUBLIC_API declaration per precision flavour."""
    lines = [NL + "// " + routine.description + ": " + routine.short_names() + NL]
    lines.extend(routine.routine_header_c(flavour, 38, " PUBLIC_API") + ";" + NL
                 for flavour in routine.flavours)
    return "".join(lines)
+
+
def clblast_c_cc(routine):
    """The C API implementation (.cpp).

    For each precision flavour, emits a C wrapper that casts the arguments
    and forwards to the templated C++ API, translating any exception into a
    CLBlastStatusCode via DispatchExceptionForC.
    """
    result = NL + "// " + routine.name.upper() + NL
    for flavour in routine.flavours:
        # Routines without scalar arguments need the explicit template parameter
        template = "<" + flavour.template + ">" if routine.no_scalars() else ""
        indent = " " * (16 + routine.length() + len(template))  # aligns the forwarded argument list
        result += routine.routine_header_c(flavour, 27, "") + " {" + NL
        if routine.batched == 1:
            result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL
        result += " try {" + NL
        result += " return static_cast<CLBlastStatusCode>(" + NL
        result += " clblast::" + routine.capitalized_name() + template + "("
        result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)])
        result += "," + NL + indent + "queue, event)" + NL
        result += " );" + NL
        result += " } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }" + NL
        result += "}" + NL
    return result
+
+
def clblast_netlib_c_h(routine):
    """The Netlib CBLAS API header (.h): declarations for the non-half precisions only."""
    result = NL + "// " + routine.description + ": " + routine.short_names() + NL
    for flavour in routine.flavours:
        if flavour.precision_name not in ["S", "D", "C", "Z"]:
            continue  # no Netlib CBLAS entry point for half precision
        result += routine.routine_header_netlib(flavour, 20, " PUBLIC_API") + ";" + NL
    return result
+
+
def clblast_netlib_c_cc(routine):
    """The Netlib CBLAS API implementation (.cpp).

    Per non-half precision flavour, emits a blocking CBLAS-style wrapper
    that sets up an OpenCL device/context/queue, copies host data into
    device buffers, calls the CLBlast C++ API, checks the status, and copies
    the results back. Scalar-result routines return the value directly (the
    real part only for complex types).
    """
    result = NL + "// " + routine.name.upper() + NL
    for flavour in routine.flavours:

        # There is a version available in CBLAS
        if flavour.precision_name in ["S", "D", "C", "Z"]:
            template = "<" + flavour.template + ">" if routine.no_scalars() else ""
            # NOTE(review): name_postfix appears unused below — verify whether it can be removed
            name_postfix = "_sub" if routine.name in routine.routines_scalar_no_return() else ""
            indent = " " * (21 + routine.length() + len(template))  # aligns the forwarded argument list
            result += routine.routine_header_netlib(flavour, 9, "") + " {" + NL

            # Initialize OpenCL
            result += " OPTIONAL_STATIC auto device = get_device();" + NL
            result += " OPTIONAL_STATIC auto context = clblast::Context(device);" + NL
            result += " auto queue = clblast::Queue(context, device);" + NL

            # Set alpha and beta
            result += "".join(" " + s + NL for s in routine.scalar_create_cpp(flavour))

            # Copy data structures to the device
            for i, name in enumerate(routine.inputs + routine.outputs):
                result += " " + routine.set_size(name, routine.buffer_sizes[i]) + NL
            for i, name in enumerate(routine.inputs + routine.outputs):
                buffer_type = routine.get_buffer_type(name, flavour)
                result += " " + routine.create_buffer(name, buffer_type) + NL
                if name in routine.scalar_buffers_second_non_pointer():
                    # Scalars passed by value still need a one-element host array to upload
                    result += " " + buffer_type + " " + name + "_vec[1]; " + name + "_vec[0] = " + name + ";" + NL
            for name in routine.inputs + routine.outputs:
                if name not in routine.scalar_buffers_first():
                    prefix = "" if name in routine.outputs else "const "
                    buffer_type = routine.get_buffer_type(name, flavour)
                    result += " " + routine.write_buffer(name, prefix + buffer_type) + NL

            # The function call
            result += " auto queue_cl = queue();" + NL
            result += " auto s = clblast::" + routine.name.capitalize() + template + "("
            result += ("," + NL + indent).join([a for a in routine.arguments_netlib(flavour, indent)])
            result += "," + NL + indent + "&queue_cl);" + NL

            # Error handling
            result += " if (s != clblast::StatusCode::kSuccess) {" + NL
            result += " throw std::runtime_error(\"CLBlast returned with error code \" + clblast::ToString(s));" + NL
            result += " }" + NL

            # Copy back and clean-up
            for name in routine.outputs:
                if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return():
                    buffer_type = routine.get_buffer_type(name, flavour)
                    result += " " + buffer_type + " " + name + "[" + name + "_size];" + NL
            for name in routine.outputs:
                buffer_type = routine.get_buffer_type(name, flavour)
                result += " " + routine.read_buffer(name, buffer_type) + NL
            for name in routine.outputs:
                if name in routine.scalar_buffers_first() and routine.name not in routine.routines_scalar_no_return():
                    result += " return " + name + "[0]"
                    if flavour.buffer_type in ["float2", "double2"]:
                        if name not in routine.index_buffers():
                            result += ".real()"  # complex scalar results return the real part only
                    result += ";" + NL
            result += "}" + NL
    return result
+
+
def wrapper_clblas(routine):
    """The wrapper to the reference clBLAS routines (for performance/correctness testing).

    Precisions that clBLAS provides (S, D, C, Z) are forwarded directly,
    allocating a scratch buffer when the routine requires one. For half
    precision the buffers are converted to float, the float routine is
    called, and the outputs are converted back.
    """
    result = ""
    if routine.has_tests:
        result += NL + "// Forwards the clBLAS calls for %s" % routine.short_names_tested() + NL
        if routine.no_scalars():
            result += routine.routine_header_wrapper_clblas(routine.template, True, 21) + ";" + NL
        for flavour in routine.flavours:
            result += routine.routine_header_wrapper_clblas(flavour, False, 21) + " {" + NL

            # There is a version available in clBLAS
            if flavour.precision_name in ["S", "D", "C", "Z"]:
                indent = " " * (17 + routine.length())
                arguments = routine.arguments_wrapper_clblas(flavour)
                if routine.scratch:
                    # Some clBLAS routines need an explicit scratch buffer as an extra argument
                    result += " auto queue = Queue(queues[0]);" + NL
                    result += " auto context = queue.GetContext();" + NL
                    result += " auto scratch_buffer = Buffer<" + flavour.template + ">"
                    result += "(context, " + routine.scratch + ");" + NL
                    arguments += ["scratch_buffer()"]
                result += " return clblas" + flavour.name + routine.name + "("
                result += ("," + NL + indent).join([a for a in arguments])
                result += "," + NL + indent + "num_queues, queues, num_wait_events, wait_events, events);"

            # There is no clBLAS available, forward the call to one of the available functions
            else:  # Half-precision
                indent = " " * (24 + routine.length())

                # Convert to float (note: also integer buffers are stored as half/float)
                for buf in routine.inputs + routine.outputs:
                    result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer, queues[0]);" + NL

                # Call the float routine
                result += " auto status = clblasX" + routine.name + "("
                result += ("," + NL + indent).join([a for a in routine.arguments_half()])
                result += "," + NL + indent + "num_queues, queues, num_wait_events, wait_events, events);"
                result += NL

                # Convert back to half
                for buf in routine.outputs:
                    result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis, queues[0]);" + NL
                result += " return status;"

            # Complete
            result += NL + "}" + NL
    return result
+
+
def wrapper_cblas(routine):
    """The wrapper to the reference CBLAS routines (for performance/correctness testing).

    For S/D/C/Z precisions this builds a cblas_* call, handling three special
    shapes of scalar output: complex dot-products (the *_sub variant with an
    output pointer), index results (cast and stored as int), and plain scalar
    results (assigned into the output buffer). Half precision is emulated by
    converting to float, calling the float routine, and converting back.
    """
    result = ""
    if routine.has_tests:
        result += NL + "// Forwards the Netlib BLAS calls for %s" % routine.short_names_tested() + NL
        for flavour in routine.flavours:
            result += routine.routine_header_wrapper_cblas(flavour, 12) + " {" + NL

            # There is a version available in CBLAS
            if flavour.precision_name in ["S", "D", "C", "Z"]:
                indent = " " * (10 + routine.length())
                arguments = routine.arguments_wrapper_cblas(flavour)

                # Complex scalars: CBLAS passes them as a two-element array
                for scalar in routine.scalars:
                    if flavour.is_complex(scalar):
                        result += " const auto " + scalar + "_array = std::vector<" + flavour.buffer_type[:-1] + ">"
                        result += "{" + scalar + ".real(), " + scalar + ".imag()};" + NL

                # Special case for scalar outputs: the pieces below are assembled
                # around the cblas_* call depending on the kind of output
                assignment = ""
                postfix, postpostfix = "", ""
                end_of_line = ""
                extra_argument = ""
                for output_buffer in routine.outputs:
                    if output_buffer in routine.scalar_buffers_first():
                        if flavour in [datatype.C, datatype.Z]:
                            # Complex dot-products: use the *_sub variant with an output pointer
                            postfix += "_sub"
                            indent += " "
                            extra_argument += "," + NL + indent
                            extra_argument += "reinterpret_cast<return_pointer_" + flavour.buffer_type[:-1] + ">"
                            extra_argument += "(&" + output_buffer + "_buffer[" + output_buffer + "_offset])"
                        elif output_buffer in routine.index_buffers():
                            # Index results (e.g. iamax): cast the buffer and store as int
                            assignment = "reinterpret_cast<int*>(&" + output_buffer + "_buffer[0])[" + output_buffer + "_offset] = static_cast<int>("
                            postpostfix = ")"
                            indent += " " * (len(assignment) + 1)
                        else:
                            # Plain scalar result: assign into the output buffer
                            assignment = output_buffer + "_buffer[" + output_buffer + "_offset]"
                            if flavour.name in ["Sc", "Dz"]:
                                # Real-valued result of a complex routine: set the real part
                                assignment += ".real("
                                end_of_line += ")"
                            else:
                                assignment += " = "
                            indent += " " * len(assignment)

                result += " " + assignment + "cblas_" + flavour.name.lower() + routine.name + postfix + "("
                result += ("," + NL + indent).join([a for a in arguments])
                result += extra_argument + end_of_line + ")" + postpostfix + ";" + NL

            # There is no CBLAS available, forward the call to one of the available functions
            else:  # Half-precision
                indent = " " * (9 + routine.length())

                # Convert to float (note: also integer buffers are stored as half/float)
                for buf in routine.inputs + routine.outputs:
                    result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer);" + NL

                # Call the float routine
                result += " cblasX" + routine.name + "("
                result += ("," + NL + indent).join([a for a in routine.arguments_half()])
                result += ");" + NL

                # Convert back to half
                for buf in routine.outputs:
                    result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis);" + NL

            # Complete
            result += "}" + NL
    return result
+
+
def wrapper_cublas(routine):
    """The wrapper to the reference cuBLAS routines (for performance/correctness testing).

    S/D/C/Z precisions forward to cuBLAS (row-major layouts are rejected,
    complex scalars are repacked into cuComplex/cuDoubleComplex). Half
    precision is not forwarded and returns CUBLAS_STATUS_NOT_SUPPORTED; a
    float-conversion fallback is kept below, commented out.
    """
    result = ""
    if routine.has_tests:
        result += NL + "// Forwards the cuBLAS calls for %s" % routine.short_names_tested() + NL
        if routine.no_scalars():
            result += routine.routine_header_wrapper_cublas(routine.template, True, 23) + ";" + NL
        for flavour in routine.flavours:
            result += routine.routine_header_wrapper_cublas(flavour, False, 23) + " {" + NL

            # There is a version available in cuBLAS
            if flavour.precision_name in ["S", "D", "C", "Z"]:
                indent = " " * (24 + routine.length())
                arguments = routine.arguments_wrapper_cublas(flavour)

                # Handles row-major: not supported by cuBLAS
                if routine.has_layout():
                    result += " if (layout == Layout::kRowMajor) { return CUBLAS_STATUS_NOT_SUPPORTED; }" + NL

                # Complex scalars: repack into the cuBLAS complex structs
                for scalar in routine.scalars:
                    if flavour.is_complex(scalar):
                        cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex"
                        result += " " + cuda_complex + " " + scalar + "_cuda;" + NL
                        result += " " + scalar + "_cuda.x = " + scalar + ".real();" + NL
                        result += " " + scalar + "_cuda.y = " + scalar + ".imag();" + NL

                # Calls the cuBLAS routine
                result += " auto status = cublas" + flavour.name_cublas() + routine.name + "(handle, "
                result += ("," + NL + indent).join([a for a in arguments]) + ");" + NL
                result += " cudaDeviceSynchronize();" + NL
                result += " return status;"

            # There is no cuBLAS available, forward the call to one of the available functions
            else:  # Half-precision
                result += " return CUBLAS_STATUS_NOT_SUPPORTED;"
                # Disabled float-conversion fallback, kept for reference:
                # indent = " " * (24 + routine.length())

                # # Convert to float (note: also integer buffers are stored as half/float)
                # for buf in routine.inputs + routine.outputs:
                #     result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer, queues[0]);" + NL

                # # Call the float routine
                # result += " return cublasX" + routine.name + "(handle,"
                # result += ("," + NL + indent).join([a for a in routine.arguments_half()]) + ");" + NL
                # result += " cudaDeviceSynchronize();" + NL
                # result += " return status;"

                # # Convert back to half
                # for buf in routine.outputs:
                #     result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis, queues[0]);" + NL
                # result += " return status;"

            # Complete
            result += NL + "}" + NL
    return result
+
+
def performance_test(routine, level_string):
    """Generates the body of a performance test (client) for a specific routine.

    Emits a C++ main() that dispatches on the precision requested on the
    command line and runs the corresponding RunClient instantiation; an
    unsupported precision throws at runtime.
    """
    result = ""
    result += "#include \"test/performance/client.hpp\"" + NL
    result += "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL
    result += "// Main function (not within the clblast namespace)" + NL
    result += "int main(int argc, char *argv[]) {" + NL
    result += " const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);" + NL
    # The routine's first flavour determines the default precision
    default = convert.precision_to_full_name(routine.flavours[0].precision_name)
    result += " switch(clblast::GetPrecision(command_line_args, clblast::Precision::k" + default + ")) {" + NL
    for precision in ["H", "S", "D", "C", "Z"]:
        result += " case clblast::Precision::k" + convert.precision_to_full_name(precision) + ":"
        found = False
        for flavour in routine.flavours:
            if flavour.precision_name == precision:
                # Plain GEMM has an extra template argument selecting the kernel variant
                extra_template_argument = "0, " if routine.name == "gemm" and routine.batched == 0 else ""
                result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name()
                result += flavour.test_template(extra_template_argument)
                result += ">(argc, argv); break;" + NL
                found = True
        if not found:
            # Precision not supported by this routine
            result += " throw std::runtime_error(\"Unsupported precision mode\");" + NL
    result += " }" + NL
    result += " return 0;" + NL
    result += "}" + NL
    return result
+
+
def correctness_test(routine, level_string):
    """Generates the body of a correctness test for a specific routine."""
    preamble = [
        "#include \"test/correctness/testblas.hpp\"" + NL,
        "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL,
        "// Main function (not within the clblast namespace)" + NL,
        "int main(int argc, char *argv[]) {" + NL,
        " auto errors = size_t{0};" + NL,
    ]
    result = "".join(preamble)
    # Plain GEMM is tested with both kernel variants via an extra template argument
    if routine.name == "gemm" and routine.batched == 0:
        extra_template_arguments = ["1, ", "2, "]
    else:
        extra_template_arguments = [""]
    not_first = "false"
    for extra_template_argument in extra_template_arguments:
        for flavour in routine.flavours:
            result += (" errors += clblast::RunTests<clblast::TestX" + routine.plain_name() +
                       flavour.test_template(extra_template_argument) +
                       ">(argc, argv, " + not_first + ", \"" + flavour.name + routine.upper_name() + "\");" + NL)
            not_first = "true"
    result += " if (errors > 0) { return 1; } else { return 0; }" + NL
    result += "}" + NL
    return result
diff --git a/scripts/generator/generator/datatype.py b/scripts/generator/generator/datatype.py
new file mode 100644
index 00000000..f2b1c9e3
--- /dev/null
+++ b/scripts/generator/generator/datatype.py
@@ -0,0 +1,119 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+
+# Short-hands for data-types
+D_HALF = "half"
+D_FLOAT = "float"
+D_DOUBLE = "double"
+D_FLOAT2 = "float2"
+D_DOUBLE2 = "double2"
+D_HALF_OPENCL = "cl_half"
+D_FLOAT2_OPENCL = "cl_float2"
+D_DOUBLE2_OPENCL = "cl_double2"
+
+
+class DataType:
+ """Class holding data-type and precision information"""
+
+ def __init__(self, precision_name, name, template, scalars, buffer_type):
+ self.precision_name = precision_name
+ self.name = name
+ self.template = template
+ self.alpha_cpp = scalars[0]
+ self.beta_cpp = scalars[1]
+ self.alpha_cl = scalars[2]
+ self.beta_cl = scalars[3]
+ self.buffer_type = buffer_type
+
+ def use_alpha(self, postfix=""):
+ """Outputs the name of the data-type (alpha/beta), possibly transforming into the right type"""
+ if self.alpha_cpp in [D_FLOAT2, D_DOUBLE2]:
+ return self.alpha_cpp + "{alpha" + postfix + ".s[0], alpha" + postfix + ".s[1]}"
+ return "alpha" + postfix
+
+ def use_beta(self, postfix=""):
+ """As above, but for beta instead of alpha"""
+ if self.beta_cpp in [D_FLOAT2, D_DOUBLE2]:
+ return self.beta_cpp + "{beta" + postfix + ".s[0], beta" + postfix + ".s[1]}"
+ return "beta" + postfix
+
+ def use_alpha_opencl(self):
+ """As above, but the transformation is in the opposite direction"""
+ if self.alpha_cpp in [D_FLOAT2, D_DOUBLE2]:
+ return self.alpha_cl + "{{alpha.real(), alpha.imag()}}"
+ return "alpha"
+
+ def use_beta_opencl(self):
+ """As above, but for beta instead of alpha"""
+ if self.beta_cpp in [D_FLOAT2, D_DOUBLE2]:
+ return self.beta_cl + "{{beta.real(), beta.imag()}}"
+ return "beta"
+
+ def use_alpha_clblast(self):
+ """Transforms a Netlib CBLAS parameter to CLBlast style"""
+ if self.alpha_cpp == D_FLOAT2:
+ return self.alpha_cpp + "{reinterpret_cast<const float*>(alpha)[0], reinterpret_cast<const float*>(alpha)[1]}"
+ elif self.alpha_cpp == D_DOUBLE2:
+ return self.alpha_cpp + "{reinterpret_cast<const double*>(alpha)[0], reinterpret_cast<const double*>(alpha)[1]}"
+ return "alpha"
+
+ def use_beta_clblast(self):
+ """As above, but for beta instead of alpha"""
+ if self.beta_cpp == D_FLOAT2:
+ return self.beta_cpp + "{reinterpret_cast<const float*>(beta)[0], reinterpret_cast<const float*>(beta)[1]}"
+ elif self.beta_cpp == D_DOUBLE2:
+ return self.beta_cpp + "{reinterpret_cast<const double*>(beta)[0], reinterpret_cast<const double*>(beta)[1]}"
+ return "beta"
+
+ def test_template(self, extra_template_argument):
+ """Returns the template as used in the correctness/performance tests"""
+ buffer_type = "clblast::" + self.buffer_type if self.is_non_standard() else self.buffer_type
+ beta_cpp = "clblast::" + self.beta_cpp if self.beta_cpp in [D_HALF, D_FLOAT2, D_DOUBLE2] else self.beta_cpp
+ if self.buffer_type != self.beta_cpp:
+ return "<" + extra_template_argument + buffer_type + "," + self.beta_cpp + ">, " + buffer_type + ", " + beta_cpp
+ return "<" + extra_template_argument + buffer_type + ">, " + buffer_type + ", " + beta_cpp
+
+ def is_complex(self, scalar):
+ """Current scalar is complex"""
+ return ((scalar == "alpha" and self.alpha_cpp in [D_FLOAT2, D_DOUBLE2]) or
+ (scalar == "beta" and self.beta_cpp in [D_FLOAT2, D_DOUBLE2]))
+
+ def is_non_standard(self):
+ """Current type is of a non-standard type"""
+ return self.buffer_type in [D_HALF, D_FLOAT2, D_DOUBLE2]
+
+ def name_cublas(self):
+ if "i" in self.name:
+ return "I" + self.name[1].lower()
+ return self.name
+
+
# Regular data-types: DataType(precision letter, name, C++ template type,
# [alpha C++, beta C++, alpha OpenCL, beta OpenCL], buffer type)
H = DataType("H", "H", D_HALF, [D_HALF] * 2 + [D_HALF_OPENCL] * 2, D_HALF)  # half (16)
S = DataType("S", "S", D_FLOAT, [D_FLOAT] * 4, D_FLOAT)  # single (32)
D = DataType("D", "D", D_DOUBLE, [D_DOUBLE] * 4, D_DOUBLE)  # double (64)
C = DataType("C", "C", D_FLOAT2, [D_FLOAT2] * 2 + [D_FLOAT2_OPENCL] * 2, D_FLOAT2)  # single-complex (3232)
Z = DataType("Z", "Z", D_DOUBLE2, [D_DOUBLE2] * 2 + [D_DOUBLE2_OPENCL] * 2, D_DOUBLE2)  # double-complex (6464)

# Special cases
Sc = DataType("C", "Sc", D_FLOAT2, [D_FLOAT2] * 4, D_FLOAT2)  # As C, but with real output
Dz = DataType("Z", "Dz", D_DOUBLE2, [D_DOUBLE2] * 4, D_DOUBLE2)  # As Z, but with real output
iH = DataType("H", "iH", D_HALF, [D_HALF] * 4, D_HALF)  # As H, but with integer output
iS = DataType("S", "iS", D_FLOAT, [D_FLOAT] * 4, D_FLOAT)  # As S, but with integer output
iD = DataType("D", "iD", D_DOUBLE, [D_DOUBLE] * 4, D_DOUBLE)  # As D, but with integer output
iC = DataType("C", "iC", D_FLOAT2, [D_FLOAT2] * 2 + [D_FLOAT2_OPENCL] * 2, D_FLOAT2)  # As C, but with integer output
iZ = DataType("Z", "iZ", D_DOUBLE2, [D_DOUBLE2] * 2 + [D_DOUBLE2_OPENCL] * 2, D_DOUBLE2)  # As Z, but with int output
Css = DataType("C", "C", D_FLOAT, [D_FLOAT, D_FLOAT, D_FLOAT, D_FLOAT], D_FLOAT2)  # As C, but with constants from S
Zdd = DataType("Z", "Z", D_DOUBLE, [D_DOUBLE] * 4, D_DOUBLE2)  # As Z, but with constants from D
Ccs = DataType("C", "C", D_FLOAT2 + "," + D_FLOAT, [D_FLOAT2, D_FLOAT, D_FLOAT2_OPENCL, D_FLOAT], D_FLOAT2)  # As C, but with one constant from S
Zzd = DataType("Z", "Z", D_DOUBLE2 + "," + D_DOUBLE, [D_DOUBLE2, D_DOUBLE, D_DOUBLE2_OPENCL, D_DOUBLE], D_DOUBLE2)  # As Z, but with one constant from D

# C++ template data-types (used when the generated code itself is templated)
T = DataType("T", "typename T", "T", ["T", "T", "T", "T"], "T")  # regular routine
Tc = DataType("Tc", "typename T", "std::complex<T>,T", ["T", "T", "T", "T"], "std::complex<T>")  # for herk
TU = DataType("TU", "typename T, typename U", "T,U", ["T", "U", "T", "U"], "T")  # for her2k
diff --git a/scripts/generator/generator/doc.py b/scripts/generator/generator/doc.py
new file mode 100644
index 00000000..9c73ffbc
--- /dev/null
+++ b/scripts/generator/generator/doc.py
@@ -0,0 +1,57 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
NL = "\n"


def header():
    """Returns the fixed title block at the top of the generated API documentation."""
    title_lines = ["CLBlast: API reference", "================", "", ""]
    return NL.join(title_lines) + NL
+
+
def generate(routine):
    """Generates the markdown documentation section for a single routine.

    `routine` is a Routine object (see routine.py); this relies on its name and
    description accessors plus its C/C++ header generators.
    """
    result = ""

    # Routine header: title, one-line description, and the longer details text
    result += "x" + routine.upper_name() + ": " + routine.description + NL
    result += "-------------" + NL + NL
    result += routine.details + NL + NL

    # Routine API: the C++ signature plus one C signature per precision flavour
    # (12 and 27 are presumably the signature indentation widths — confirm in routine.py)
    result += "C++ API:" + NL
    result += "```" + NL
    result += routine.routine_header_cpp(12, "") + NL
    result += "```" + NL + NL
    result += "C API:" + NL
    result += "```" + NL
    for flavour in routine.flavours:
        result += routine.routine_header_c(flavour, 27, "") + NL
    result += "```" + NL + NL

    # Routine arguments: one bullet per routine-specific argument, followed by
    # the queue/event pair that every routine takes
    result += "Arguments to " + routine.upper_name() + ":" + NL + NL
    for argument in routine.arguments_doc():
        result += "* " + argument + NL
    result += "* `cl_command_queue* queue`: "
    result += "Pointer to an OpenCL command queue associated with a context and device to execute the routine on." + NL
    result += "* `cl_event* event`: "
    result += "Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). "
    result += "This is an optional argument." + NL + NL

    # Routine requirements (section omitted entirely when there are none)
    if len(routine.requirements_doc()) > 0:
        result += "Requirements for " + routine.upper_name() + ":" + NL + NL
        for requirement in routine.requirements_doc():
            result += "* " + requirement + NL
        result += NL

    # Routine footer
    result += NL + NL
    return result
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
new file mode 100644
index 00000000..47eb2eb4
--- /dev/null
+++ b/scripts/generator/generator/pyclblast.py
@@ -0,0 +1,128 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+import os
+
+
# Line separator used in the generated Cython source.
# NOTE(review): os.linesep is "\r\n" on Windows; if the generated text is later
# written through a text-mode file handle that would become "\r\r\n" — confirm
# the output file is opened in binary mode or that generation runs on Unix only.
NL = os.linesep
# Visual separator line emitted between generated routines.
SEPARATOR = "####################################################################################################"
+
+
def to_np_dtype(flavour):
    """Maps a CLBlast precision letter (S/D/C/Z/H) to the matching NumPy dtype name."""
    dtype_by_precision = {
        "S": "float32",
        "D": "float64",
        "C": "complex64",
        "Z": "complex128",
        "H": "float16",
    }
    return dtype_by_precision[flavour.precision_name]
+
+
def scalar_cython_conversion(scalar, flavour):
    """Returns the Cython expression that casts the alpha/beta scalar to its OpenCL C type."""
    scalar_type = flavour.alpha_cl if scalar == "alpha" else flavour.beta_cl
    # Real-valued types are a plain cast; complex types are rebuilt from .real/.imag
    simple_casts = {"float": "cl_float", "double": "cl_double",
                    "cl_half": "cl_half", "half": "cl_half"}
    complex_casts = {"cl_float2": "cl_float2", "float2": "cl_float2",
                     "cl_double2": "cl_double2", "double2": "cl_double2"}
    if scalar_type in simple_casts:
        return "<" + simple_casts[scalar_type] + ">" + scalar
    if scalar_type in complex_casts:
        ctype = complex_casts[scalar_type]
        return "<" + ctype + ">" + ctype + "(x=" + scalar + ".real,y=" + scalar + ".imag)"
    raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type))
+
+
def generate_pyx(routine):
    """Generates the Cython wrapper (.pyx) source text for a single routine.

    Returns an empty string for routines that are not implemented or not part
    of a supported BLAS level.
    """
    result = ""
    if routine.implemented and routine.plain_name() and routine.level in ["1", "2a", "2b", "3"]:
        indent = "    "

        result += SEPARATOR + NL
        result += "# " + routine.description + ": " + routine.short_names() + NL
        result += SEPARATOR + NL
        result += NL

        # Reference C definition: declare the extern C entry point for every
        # regular precision flavour; remember the matching NumPy dtype names
        result += "cdef extern from \"clblast_c.h\":" + NL
        np_dtypes = []
        for flavour in routine.flavours:
            if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
                result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
                result += ", ".join(routine.arguments_def_c(flavour)) + ","
                result += "cl_command_queue* queue, cl_event* event)" + NL
                np_dtypes.append(to_np_dtype(flavour))
        result += "" + NL

        # Function definition
        # NOTE(review): a buffer listed in both inputs and outputs would appear
        # twice here and be checked twice below — confirm that never happens
        buffers = routine.inputs[:] + routine.outputs[:]
        result += "def " + routine.plain_name() + "(queue, "
        result += ", ".join(routine.arguments_python()) + "):" + NL

        # Documentation
        result += indent + "\"\"\"" + NL
        result += indent + "x" + routine.upper_name() + ": " + routine.description + NL
        result += indent + "\"\"\"" + NL
        result += NL

        # Data types and checks: all buffers must share one of the supported dtypes
        result += indent + "dtype = check_dtype([" + ", ".join(buffers) + "], "
        result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
        for buf in buffers:
            if buf in routine.buffers_vector():
                result += indent + "check_vector("
            else:
                result += indent + "check_matrix("
            result += buf + ", \"" + buf + "\")" + NL
        result += NL

        # Buffer transformation: extract the raw cl_mem handles from the
        # PyOpenCL array objects
        for buf in buffers:
            result += indent + "cdef cl_mem " + buf + "_buffer = <cl_mem><size_t>" + buf + ".base_data.int_ptr" + NL
        result += NL

        result += indent + "cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
        result += indent + "cdef cl_event event = NULL" + NL

        # Translate the Python-level boolean options into CLBlast enum values
        for option in routine.options:
            if option == "a_transpose":
                result += indent + "a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL
            if option == "b_transpose":
                result += indent + "b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL
            if option == "ab_transpose":
                result += indent + "ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL
            if option == "side":
                result += indent + "side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL
            if option == "triangle":
                result += indent + "triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL
            if option == "diagonal":
                result += indent + "diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL

        result += "" + NL
        result += indent + "cdef CLBlastStatusCode err" + NL
        # Dispatch on the detected NumPy dtype to the matching C entry point;
        # layout is hard-coded to row-major, scalars are cast to OpenCL types
        if_prefix = ""
        for flavour in routine.flavours:
            if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
                np_dtype = to_np_dtype(flavour)
                argument_names = [x.
                                  replace("layout", "CLBlastLayoutRowMajor").
                                  replace("alpha", scalar_cython_conversion("alpha", flavour)).
                                  replace("beta", scalar_cython_conversion("beta", flavour))
                                  for x in routine.arguments()]
                result += indent + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
                result += indent + indent + "err = CLBlast" + flavour.name + routine.plain_name()
                result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL
                if_prefix = "el"

        result += indent + "else:" + NL
        result += indent + indent + "raise ValueError(\"PyCLBlast: Unrecognized data-type '%s'\" % dtype)" + NL
        result += indent + "if err != CLBlastSuccess:" + NL
        result += indent + indent + "raise RuntimeError(\"PyCLBlast: 'CLBlastX" + routine.plain_name() + "' failed: %s\" % get_status_message(err))" + NL
        result += indent + "return cl.Event.from_int_ptr(<size_t>event)" + NL
        result += NL

    return result
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
new file mode 100644
index 00000000..3b5a6b76
--- /dev/null
+++ b/scripts/generator/generator/routine.py
@@ -0,0 +1,964 @@
+
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
+# PEP8 Python style guide and uses a max-width of 120 characters per line.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+
+from itertools import chain
+
+import generator.convert as convert
+
+
+class Routine:
+ """Class holding routine-specific information (e.g. name, which arguments, which precisions)"""
    def __init__(self, implemented, has_tests, batched_strided, temp_buffer, level, name, template, flavours, sizes, options,
                 inputs, outputs, buffer_sizes, scalars, scratch,
                 description, details, requirements):
        """Stores the per-routine metadata; see the attribute comments below."""
        self.implemented = implemented  # False for not-yet-implemented routines
        self.has_tests = has_tests  # presumably whether test clients exist — confirm against generator.py
        self.batched = batched_strided  # 0: regular, 1: batched, 2: strided-batched
        self.temp_buffer = temp_buffer  # presumably whether a temporary buffer argument is needed — confirm
        self.level = level  # BLAS level as a string, e.g. "1", "2a", "2b", "3"
        self.name = name  # short lower-case routine name, e.g. "gemm"
        self.template = template  # DataType used for documentation and CLCudaAPI casts
        self.flavours = flavours  # the precision flavours (DataType objects)
        self.sizes = sizes  # size argument names, e.g. ["m", "n", "k"]
        self.options = options  # option argument names, e.g. ["layout", "a_transpose"]
        self.inputs = inputs  # names of the input buffers
        self.outputs = outputs  # names of the output buffers
        self.buffer_sizes = buffer_sizes  # buffer size expressions — usage not visible here, confirm
        self.scalars = scalars  # scalar argument names, e.g. ["alpha", "beta"]
        self.scratch = scratch  # Scratch buffer (e.g. for xDOT)
        self.description = description  # one-line description for docs and generated comments
        self.details = details  # longer description used in the API documentation
        self.requirements = requirements  # documented requirements on the arguments
+
+ def lowercase_name(self):
+ postfix = "strided" if self.batched == 2 else ""
+ postfix += "batched" if self.batched != 0 else ""
+ return self.name + postfix
+
+ def plain_name(self):
+ postfix = "Strided" if self.batched == 2 else ""
+ postfix += "Batched" if self.batched != 0 else ""
+ return self.name + postfix
+
+ def capitalized_name(self):
+ postfix = "Strided" if self.batched == 2 else ""
+ postfix += "Batched" if self.batched != 0 else ""
+ return self.name.capitalize() + postfix
+
+ def upper_name(self):
+ postfix = "STRIDED" if self.batched == 2 else ""
+ postfix += "BATCHED" if self.batched != 0 else ""
+ return self.name.upper() + postfix
+
+ def b_star(self):
+ return "*" if self.batched == 1 else ""
+
+ def b_s(self):
+ return "s" if self.batched == 1 else ""
+
+ def batch_count_def(self):
+ return ["const size_t batch_count"] if self.batched != 0 else []
+
+ def batch_count_list(self):
+ return ["batch_count"] if self.batched != 0 else []
+
+ def batch_count_type(self):
+ return ["const size_t"] if self.batched != 0 else []
+
+ def batch_count_doc(self):
+ return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched != 0 else []
+
+ def batched_transform_to_cpp(self):
+ result = []
+ for scalar in self.scalars:
+ result.append("auto " + scalar + "s_cpp = std::vector<T>();")
+ for buffer_name in self.inputs + self.outputs:
+ result.append("auto " + buffer_name + "_offsets_cpp = std::vector<size_t>();")
+ result.append("for (auto batch = size_t{0}; batch < batch_count; ++batch) {")
+ for scalar in self.scalars:
+ result.append(" " + scalar + "s_cpp.push_back(" + scalar + "s[batch]);")
+ for buffer_name in self.inputs + self.outputs:
+ result.append(" " + buffer_name + "_offsets_cpp.push_back(" + buffer_name + "_offsets[batch]);")
+ result.append("}")
+ return result
+
+ def batched_transform_to_complex(self, flavour):
+ result = []
+ for scalar in self.scalars:
+ result.append("auto " + scalar + "s_cpp = std::vector<" + flavour.buffer_type + ">();")
+ result.append("for (auto batch = size_t{0}; batch < batch_count; ++batch) {")
+ for scalar in self.scalars:
+ content = scalar
+ if scalar == "alpha":
+ content = flavour.use_alpha(postfix="s[batch]")
+ elif scalar == "beta":
+ content = flavour.use_beta(postfix="s[batch]")
+ result.append(" " + scalar + "s_cpp.push_back(" + content + ");")
+ result.append("}")
+ return result
+
+ @staticmethod
+ def scalar_buffers_first():
+ """List of scalar buffers"""
+ return ["dot", "nrm2", "asum", "sum", "imax", "imin"]
+
+ @staticmethod
+ def scalar_buffers_second():
+ """List of scalar buffers"""
+ return ["sa", "sb", "sc", "ss", "sd1", "sd2", "sx1", "sy1", "sparam"]
+
+ @staticmethod
+ def scalar_buffers_second_non_pointer():
+ """As above, but these ones are not passed as pointers but as scalars instead"""
+ return ["sy1"]
+
+ @staticmethod
+ def other_scalars():
+ """List of scalars other than alpha and beta"""
+ return ["cos", "sin"]
+
+ @staticmethod
+ def index_buffers():
+ """List of buffers with unsigned int type"""
+ return ["imax", "imin"]
+
+ @staticmethod
+ def postfix(name):
+ """Retrieves the postfix for a buffer"""
+ return "inc" if (name in ["x", "y", "z"]) else "ld"
+
+ @staticmethod
+ def buffers_vector():
+ """Distinguish between vectors and matrices"""
+ return ["x", "y", "z"]
+
+ @staticmethod
+ def buffers_matrix():
+ """Distinguish between vectors and matrices"""
+ return ["a", "b", "c", "ap"]
+
+ @staticmethod
+ def buffers_tensor():
+ """Distinguish between vectors and matrices and tensors"""
+ return ["im", "col", "kernel", "result"]
+
+ @staticmethod
+ def routines_scalar_no_return():
+ return ["dotu", "dotc"]
+
+ @staticmethod
+ def set_size(name, size):
+ """Sets the size of a buffer"""
+ return "const auto " + name + "_size = " + size + ";"
+
+ @staticmethod
+ def create_buffer(name, template):
+ """Creates a new CLCudaAPI buffer"""
+ return "auto " + name + "_buffer = clblast::Buffer<" + template + ">(context, " + name + "_size);"
+
+ def write_buffer(self, name, template):
+ """Writes to a CLCudaAPI buffer"""
+ postfix = ""
+ if name in self.scalar_buffers_second_non_pointer():
+ postfix = "_vec"
+ data_structure = "reinterpret_cast<" + template + "*>(" + name + postfix + ")"
+ return name + "_buffer.Write(queue, " + name + "_size, " + data_structure + ");"
+
+ @staticmethod
+ def read_buffer(name, template):
+ """Reads from a CLCudaAPI buffer"""
+ data_structure = "reinterpret_cast<" + template + "*>(" + name + ")"
+ return name + "_buffer.Read(queue, " + name + "_size, " + data_structure + ");"
+
+ def non_index_inputs(self):
+ """Lists of input/output buffers not index (integer)"""
+ buffers = self.inputs[:] # make a copy
+ for i in self.index_buffers():
+ if i in buffers:
+ buffers.remove(i)
+ return buffers
+
+ def non_index_outputs(self):
+ """Lists of input/output buffers not index (integer)"""
+ buffers = self.outputs[:] # make a copy
+ for i in self.index_buffers():
+ if i in buffers:
+ buffers.remove(i)
+ return buffers
+
+ def buffers_without_ld_inc(self):
+ """List of buffers without 'inc' or 'ld'"""
+ return self.scalar_buffers_first() + self.scalar_buffers_second() + ["ap", "im", "col", "kernel", "result"]
+
+ def get_buffer_type(self, name, flavour):
+ if name in self.index_buffers():
+ return "int"
+ return flavour.buffer_type
+
+ def length(self):
+ """Retrieves the number of characters in the routine's name"""
+ return len(self.capitalized_name())
+
+ def no_scalars(self):
+ """Determines whether or not this routine has scalar arguments (alpha/beta)"""
+ return self.scalars == [] or self.name in ["im2col", "col2im", "convgemm"]
+
+ def has_layout(self):
+ """Determines whether the layout is an argument"""
+ return "layout" in self.options
+
+ def short_names(self):
+ """Returns the upper-case names of these routines (all flavours)"""
+ return "/".join([f.name + self.upper_name() for f in self.flavours])
+
+ def short_names_tested(self):
+ """As above, but excludes some"""
+ names = [f.name + self.upper_name() for f in self.flavours]
+ if "H" + self.upper_name() in names:
+ names.remove("H" + self.upper_name())
+ return "/".join(names)
+
+ def buffers_first(self):
+ """Determines which buffers go first (between alpha and beta) and which ones go after"""
+ if self.level == "2b" or self.name == "had":
+ return ["x", "y"]
+ extra_buffer = "col" if self.name == "col2im" else "im"
+ return ["ap", "a", "b", "x", extra_buffer, "kernel"]
+
+ def buffers_second(self):
+ if self.level == "2b" or self.name == "had":
+ return ["z", "ap", "a", "b", "c"]
+ extra_buffer = "im" if self.name == "col2im" else "col"
+ return ["y", "c", extra_buffer, "result"]
+
+ def buffer(self, name):
+ """Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')"""
+ if name in self.inputs or name in self.outputs:
+ a = [name + "_buffer"]
+ b = [name + "_offset" + self.b_s()]
+ c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
+ if self.batched == 2:
+ c += [name + "_stride"]
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_bis(self, name):
+ """As above but with a '_bis' suffix for the buffer name"""
+ if name in self.inputs or name in self.outputs:
+ a = [name + "_buffer_bis"]
+ b = [name + "_offset"]
+ c = [name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ if self.batched == 2:
+ c += [name + "_stride"]
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_zero_offset(self, name):
+ """As above, but with an offset value of zero"""
+ if name in self.inputs or name in self.outputs:
+ a = [name + "_buffer()"]
+ b = ["0"]
+ c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_def(self, name):
+ """As above but with data-types"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ a = [prefix + "cl_mem " + name + "_buffer"]
+ b = ["const size_t " + self.b_star() + name + "_offset" + self.b_s()]
+ c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ if self.batched == 2:
+ c += ["const size_t " + name + "_stride"]
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_def_wrapper_cl(self, name, flavour):
+ """As above but for OpenCL"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ a = [prefix + "Buffer<" + flavour.buffer_type + ">& " + name + "_buffer"]
+ b = ["const size_t " + name + "_offset"]
+ c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_def_wrapper_cuda(self, name, flavour):
+ """As above but for CUDA"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ a = [prefix + flavour.buffer_type + "* " + name + "_buffer"]
+ b = ["const size_t " + name + "_offset"]
+ c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_def_vector(self, name, flavour):
+ """As above but as vectors"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ a = [prefix + "std::vector<" + flavour.buffer_type + ">& " + name + "_buffer"]
+ b = ["const size_t " + name + "_offset"]
+ c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_def_pointer(self, name, flavour):
+ """As above but as plain C pointer"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ data_type = "void" if flavour.is_non_standard() else flavour.buffer_type
+ pointer = "" if name in self.scalar_buffers_second_non_pointer() else "*"
+ a = [prefix + data_type + pointer + " " + name + ""]
+ c = ["const int " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ return [", ".join(a + c)]
+ return []
+
+ def buffer_clcudaapi(self, name):
+ """As above but with CLCudaAPI buffers"""
+ if name in self.inputs or name in self.outputs:
+ buffer_type = "unsigned int" if (name in self.index_buffers()) else self.template.buffer_type
+ a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"]
+ b = [name + "_offsets_cpp"] if self.batched == 1 else [name + "_offset"]
+ c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
+ if self.batched == 2:
+ c += [name + "_stride"]
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_wrapper_clblas(self, name):
+ """As above but with a static cast for clBLAS wrapper"""
+ if name in self.inputs or name in self.outputs:
+ a = [name + "_buffer()"]
+ b = [name + "_offset"]
+ c = []
+ if name in ["x", "y", "z"]:
+ c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
+ elif name in ["a", "b", "c"]:
+ c = [name + "_" + self.postfix(name)]
+ return [", ".join(a + b + c)]
+ return []
+
+ def buffer_wrapper_cblas(self, name, flavour):
+ """As above but with a static cast for CBLAS wrapper"""
+ prefix = "const " if name in self.inputs else ""
+ if name in self.inputs or name in self.outputs:
+ if name == "sy1":
+ a = [name + "_buffer[" + name + "_offset]"]
+ elif flavour.precision_name in ["C", "Z"]:
+ a = ["reinterpret_cast<" + prefix + flavour.buffer_type[:-1] + "*>" +
+ "(&" + name + "_buffer[" + name + "_offset])"]
+ else:
+ a = ["&" + name + "_buffer[" + name + "_offset]"]
+ c = []
+ if name in ["x", "y", "z", "a", "b", "c"]:
+ c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
+ return [", ".join(a + c)]
+ return []
+
    def buffer_wrapper_cublas(self, name, flavour):
        """As above, but for the cuBLAS wrapper."""
        prefix = "const " if name in self.inputs else ""
        if name in self.inputs or name in self.outputs:
            if name in self.index_buffers():
                # Index results are passed as plain int pointers
                a = ["reinterpret_cast<int*>(&" + name + "_buffer[" + name + "_offset])"]
            elif name in self.outputs and flavour.name in ["Sc", "Dz"]:
                # Sc/Dz flavours have a real-valued output (see datatype.py)
                dtype = "float" if flavour.name == "Sc" else "double"
                a = ["reinterpret_cast<" + dtype + "*>(&" + name + "_buffer[" + name + "_offset])"]
            elif flavour.precision_name in ["C", "Z"]:
                # Complex data is reinterpreted as the cuBLAS complex types
                cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex"
                a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" +
                     "(&" + name + "_buffer[" + name + "_offset])"]
            else:
                a = ["&" + name + "_buffer[" + name + "_offset]"]
            c = []
            if name in ["x", "y", "z"]:
                c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
            elif name in ["a", "b", "c"]:
                c = [name + "_" + self.postfix(name)]
            result = [", ".join(a + c)]
            if self.name == "trmm" and name == "a":
                # The 'a' argument is duplicated for trmm — presumably because the
                # cuBLAS trmm API expects the matrix twice; confirm against cuBLAS docs
                result *= 2
            return result
        return []
+
    def buffer_type(self, name):
        """As above, but only the data-types of the buffer arguments."""
        prefix = "const " if (name in self.inputs) else ""
        if (name in self.inputs) or (name in self.outputs):
            a = [prefix + "cl_mem"]
            b = ["const size_t" + self.b_star()]
            c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else []
            if self.batched == 2:
                # strided-batched routines take an extra stride argument
                c += ["const size_t"]
            return [", ".join(a + b + c)]
        return []

    def buffer_doc(self, name):
        """Markdown documentation lines for a buffer's arguments (handle, offset, inc/ld, stride)."""
        prefix = "const " if (name in self.inputs) else ""
        inout = "input" if (name in self.inputs) else "output"
        if (name in self.inputs) or (name in self.outputs):
            # Wording differs for matrices, tensors, and vectors
            math_name = name.upper() + " matrix" if (name in self.buffers_matrix()) else name + " tensor" if (name in self.buffers_tensor()) else name + " vector"
            inc_ld_description = "Leading dimension " if (name in self.buffers_matrix()) else "Stride/increment "
            a = ["`" + prefix + "cl_mem " + name + "_buffer`: OpenCL buffer to store the " + inout + " " + math_name + "."]
            b = ["`const size_t " + self.b_star() + name + "_offset" + self.b_s() + "`: The offset" + self.b_s() + " in elements from the start of the " + inout + " " + math_name + "."]
            c = []
            if name not in self.buffers_without_ld_inc():
                c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " +
                     inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."]
            if self.batched == 2:
                c += ["`const size_t " + name + "_stride`: The (fixed) stride between two batches of the " + name.upper() + " matrix."]
            return a + b + c
        return []
+
+ def scalar(self, name):
+ """Retrieves the name of a scalar (alpha/beta)"""
+ if name in self.scalars:
+ if self.batched == 1:
+ return [name + "s_cpp"]
+ return [name]
+ return []
+
+ def scalar_cpp(self, name):
+ """As above, but with _cpp as a suffix"""
+ if name in self.scalars:
+ return [name + "_cpp"]
+ return []
+
+ def scalar_half_to_float(self, name):
+ """As above, but converts from float to half"""
+ if name in self.scalars:
+ return ["HalfToFloat(" + name + ")"]
+ return []
+
+ def scalar_use(self, name, flavour):
+ """Retrieves the use of a scalar (alpha/beta)"""
+ if name in self.scalars:
+ if name == "alpha":
+ if self.batched == 1:
+ return ["alphas_cpp.data()"]
+ return [flavour.use_alpha()]
+ elif name == "beta":
+ if self.batched == 1:
+ return ["betas_cpp.data()"]
+ return [flavour.use_beta()]
+ return [name]
+ return []
+
+ def scalar_use_wrapper(self, name, flavour):
+ """As above, but for the clBLAS wrapper"""
+ if name in self.scalars:
+ if name == "alpha":
+ return [flavour.use_alpha_opencl()]
+ elif name == "beta":
+ return [flavour.use_beta_opencl()]
+ return [name]
+ return []
+
+ def scalar_use_wrapper_cblas(self, name, flavour):
+ """As above, but for the CBLAS wrapper"""
+ if name in self.scalars:
+ if flavour.is_complex(name):
+ return [name + "_array.data()"]
+ return [name]
+ return []
+
+ def scalar_use_wrapper_cublas(self, name, flavour):
+ """As above, but for the cuBLAS wrapper"""
+ if name in self.scalars:
+ if flavour.is_complex(name):
+ return ["&" + name + "_cuda"]
+ return ["&" + name]
+ return []
+
+ def scalar_def(self, name, flavour):
+ """Retrieves the definition of a scalar (alpha/beta)"""
+ if name in self.scalars:
+ if name == "alpha":
+ return ["const " + flavour.alpha_cl + " " + self.b_star() + name + self.b_s()]
+ return ["const " + flavour.beta_cl + " " + self.b_star() + name + self.b_s()]
+ return []
+
+ def scalar_def_plain(self, name, flavour):
+ """As above, but without 'cl_' prefix"""
+ if name in self.scalars:
+ if name == "alpha":
+ return ["const " + flavour.alpha_cpp + " " + self.b_star() + name + self.b_s()]
+ return ["const " + flavour.beta_cpp + " " + self.b_star() + name + self.b_s()]
+ return []
+
+ def scalar_def_void(self, name, flavour):
+ """Retrieves the definition of a scalar (alpha/beta) but make it a void pointer in case of non-standard types"""
+ if name in self.scalars:
+ if name == "alpha":
+ data_type = "void*" if flavour.is_complex("alpha") else flavour.alpha_cpp
+ return ["const " + data_type + " " + name]
+ data_type = "void*" if flavour.is_complex("beta") else flavour.beta_cpp
+ return ["const " + data_type + " " + name]
+ return []
+
+ def scalar_type(self, name, flavour):
+ """Retrieves the type of a scalar (alpha/beta)"""
+ if name in self.scalars:
+ if name == "alpha":
+ return ["const " + flavour.alpha_cpp + self.b_star()]
+ return ["const " + flavour.beta_cpp + self.b_star()]
+ return []
+
+ def scalar_doc(self, name):
+ """Retrieves the documentation of a scalar"""
+ if name in self.scalars:
+ if name == "alpha":
+ return ["`const " + self.template.alpha_cpp + " " + self.b_star() + name + self.b_s() + "`: Input scalar constant" + self.b_s() + "."]
+ return ["`const " + self.template.beta_cpp + " " + self.b_star() + name + self.b_s() + "`: Input scalar constant" + self.b_s() + "."]
+ return []
+
+ def scalar_create_cpp(self, flavour):
+ """Creates a C++ version of a scalar based on a void*"""
+ result = []
+ for name in self.scalars:
+ if name == "alpha":
+ result.append("const auto alpha_cpp = " + flavour.use_alpha_clblast() + ";")
+ elif name == "beta":
+ result.append("const auto beta_cpp = " + flavour.use_beta_clblast() + ";")
+ return result
+
+ def sizes_list(self):
+ """Retrieves a list of comma-separated sizes (m, n, k)"""
+ if self.sizes:
+ return [", ".join([s for s in self.sizes])]
+ return []
+
+ def sizes_list_as_int(self):
+ """Retrieves a list of comma-separated sizes (m, n, k) cast to integers"""
+ if self.sizes:
+ return [", ".join(["static_cast<int>(" + s + ")" for s in self.sizes])]
+ return []
+
+ def sizes_def(self):
+ """Retrieves the definition of the sizes (m,n,k)"""
+ if self.sizes:
+ return [", ".join(["const size_t " + s for s in self.sizes])]
+ return []
+
+ def sizes_def_netlib(self):
+ """Retrieves the definition of the sizes (m,n,k) for the CBLAS API"""
+ if self.sizes:
+ return [", ".join(["const int " + s for s in self.sizes])]
+ return []
+
+ def sizes_type(self):
+ """Retrieves the types of the sizes (m,n,k)"""
+ if self.sizes:
+ return [", ".join(["const size_t" for s in self.sizes])]
+ return []
+
+ def sizes_doc(self):
+ """# Retrieves the documentation of the sizes"""
+ if self.sizes:
+ definitions = ["`const size_t " + s + "`: Integer size argument. This value must be positive." for s in self.sizes]
+ return definitions
+ return []
+
+ def options_list(self):
+ """Retrieves a list of options"""
+ if self.options:
+ return [", ".join(self.options)]
+ return []
+
+ def options_list_no_layout(self):
+ """Retrieves a list of options"""
+ options = self.options[:]
+ if "layout" in options:
+ options.remove("layout")
+ if options:
+ return [", ".join(options)]
+ return []
+
+ def options_cast(self, indent):
+ """As above, but now casted to CLBlast data-types"""
+ if self.options:
+ options = ["static_cast<clblast::" + convert.option_to_clblast(o) + ">(" + o + ")" for o in self.options]
+ return [(",\n" + indent).join(options)]
+ return []
+
+ def options_def(self):
+ """Retrieves the definitions of the options (layout, transpose, side, etc.)"""
+ if self.options:
+ definitions = ["const " + convert.option_to_clblast(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
+ def options_def_c(self):
+ """As above, but now for the C API"""
+ if self.options:
+ definitions = ["const CLBlast" + convert.option_to_clblast(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
+ def options_def_wrapper_clblas(self):
+ """As above, but now using clBLAS data-types"""
+ if self.options:
+ definitions = ["const " + convert.option_to_clblas(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
+ def options_def_wrapper_cblas(self):
+ """As above, but now using CBLAS data-types"""
+ if self.options:
+ definitions = ["const " + convert.option_to_cblas(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
+ def options_def_wrapper_cublas(self):
+ """As above, but now using cuBLAS data-types"""
+ if self.options:
+ definitions = ["const " + convert.option_to_cublas(o) + " " + o for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
+ def options_type(self):
+ """Retrieves the types of the options (layout, transpose, side, etc.)"""
+ if self.options:
+ definitions = ["const " + convert.option_to_clblast(o) for o in self.options]
+ return [", ".join(definitions)]
+ return []
+
+ def options_doc(self):
+ """Retrieves the documentation of the options"""
+ if self.options:
+ definitions = ["`const " + convert.option_to_clblast(o) + " " + o + "`: " + convert.option_to_documentation(o) for o in self.options]
+ return definitions
+ return []
+
+ def arguments(self):
+ """Retrieves a combination of all the argument names (no types)"""
+ return (self.options_list() + self.sizes_list() +
+ list(chain(*[self.buffer(b) for b in self.scalar_buffers_first()])) +
+ self.scalar("alpha") +
+ list(chain(*[self.buffer(b) for b in self.buffers_first()])) +
+ self.scalar("beta") +
+ list(chain(*[self.buffer(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar(s) for s in self.other_scalars()])))
+
+ def arguments_half(self):
+ """As above, but with conversions from half to float"""
+ return (self.options_list() + self.sizes_list() +
+ list(chain(*[self.buffer_bis(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_half_to_float("alpha") +
+ list(chain(*[self.buffer_bis(b) for b in self.buffers_first()])) +
+ self.scalar_half_to_float("beta") +
+ list(chain(*[self.buffer_bis(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_bis(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar(s) for s in self.other_scalars()])))
+
+ def arguments_clcudaapi(self):
+ """Retrieves a combination of all the argument names, with CLCudaAPI casts"""
+ return (self.options_list() + self.sizes_list() +
+ list(chain(*[self.buffer_clcudaapi(b) for b in self.scalar_buffers_first()])) +
+ self.scalar("alpha") +
+ list(chain(*[self.buffer_clcudaapi(b) for b in self.buffers_first()])) +
+ self.scalar("beta") +
+ list(chain(*[self.buffer_clcudaapi(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_clcudaapi(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar(s) for s in self.other_scalars()])) +
+ self.batch_count_list())
+
+ def arguments_cast(self, flavour, indent):
+ """As above, but with CLBlast casts"""
+ return (self.options_cast(indent) + self.sizes_list() +
+ list(chain(*[self.buffer(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_use("alpha", flavour) +
+ list(chain(*[self.buffer(b) for b in self.buffers_first()])) +
+ self.scalar_use("beta", flavour) +
+ list(chain(*[self.buffer(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_use(s, flavour) for s in self.other_scalars()])) +
+ self.batch_count_list())
+
+ def arguments_netlib(self, flavour, indent):
+ """As above, but for the Netlib CBLAS API"""
+ return (self.options_cast(indent) + self.sizes_list() +
+ list(chain(*[self.buffer_zero_offset(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_cpp("alpha") +
+ list(chain(*[self.buffer_zero_offset(b) for b in self.buffers_first()])) +
+ self.scalar_cpp("beta") +
+ list(chain(*[self.buffer_zero_offset(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_zero_offset(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar(s) for s in self.other_scalars()])))
+
+ def arguments_wrapper_clblas(self, flavour):
+ """As above, but for the clBLAS wrapper"""
+ return (self.options_list() + self.sizes_list() +
+ list(chain(*[self.buffer_wrapper_clblas(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_use_wrapper("alpha", flavour) +
+ list(chain(*[self.buffer_wrapper_clblas(b) for b in self.buffers_first()])) +
+ self.scalar_use_wrapper("beta", flavour) +
+ list(chain(*[self.buffer_wrapper_clblas(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_wrapper_clblas(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_use_wrapper(s, flavour) for s in self.other_scalars()])))
+
+ def arguments_wrapper_cblas(self, flavour):
+ """As above, but for the CBLAS wrapper"""
+ return (self.options_list() + self.sizes_list_as_int() +
+ self.scalar_use_wrapper_cblas("alpha", flavour) +
+ list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_use_wrapper_cblas("beta", flavour) +
+ list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_use_wrapper_cblas(s, flavour) for s in self.other_scalars()])))
+
+ def arguments_wrapper_cublas(self, flavour):
+ """As above, but for the cuBLAS wrapper"""
+ return (self.options_list_no_layout() + self.sizes_list_as_int() +
+ self.scalar_use_wrapper_cublas("alpha", flavour) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_use_wrapper_cublas("beta", flavour) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_first()])) +
+ list(chain(*[self.buffer_wrapper_cublas(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_use_wrapper_cublas(s, flavour) for s in self.other_scalars()])))
+
+ def arguments_def(self, flavour):
+ """Retrieves a combination of all the argument definitions"""
+ return (self.options_def() + self.sizes_def() +
+ list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_def("alpha", flavour) +
+ list(chain(*[self.buffer_def(b) for b in self.buffers_first()])) +
+ self.scalar_def("beta", flavour) +
+ list(chain(*[self.buffer_def(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])) +
+ self.batch_count_def())
+
+ def arguments_def_netlib(self, flavour):
+ """As above, but for the Netlib CBLAS API"""
+ result=(self.options_def_c() + self.sizes_def_netlib() +
+ self.scalar_def_void("alpha", flavour) +
+ list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_def_void("beta", flavour) +
+ list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
+ if self.name in self.routines_scalar_no_return():
+ result += list(chain(*[self.buffer_def_pointer(b, flavour) for b in self.scalar_buffers_first()]))
+ result += self.batch_count_def()
+ return result
+
+ def arguments_def_c(self, flavour):
+ """As above, but for the C API"""
+ return (self.options_def_c() + self.sizes_def() +
+ list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_def("alpha", flavour) +
+ list(chain(*[self.buffer_def(b) for b in self.buffers_first()])) +
+ self.scalar_def("beta", flavour) +
+ list(chain(*[self.buffer_def(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])) +
+ self.batch_count_def())
+
+ def arguments_def_wrapper_clblas(self, flavour):
+ """As above, but clBLAS wrapper plain data-types"""
+ return (self.options_def_wrapper_clblas() + self.sizes_def() +
+ list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.scalar_buffers_first()])) +
+ self.scalar_def_plain("alpha", flavour) +
+ list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_def_plain("beta", flavour) +
+ list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
+
+ def arguments_def_wrapper_cblas(self, flavour):
+ """As above, but CBLAS wrapper plain data-types"""
+ return (self.options_def_wrapper_cblas() + self.sizes_def() +
+ list(chain(*[self.buffer_def_vector(b, flavour) for b in self.scalar_buffers_first()])) +
+ self.scalar_def_plain("alpha", flavour) +
+ list(chain(*[self.buffer_def_vector(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_def_plain("beta", flavour) +
+ list(chain(*[self.buffer_def_vector(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def_vector(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
+
+ def arguments_def_wrapper_cublas(self, flavour):
+ """As above, but cuBLAS wrapper plain data-types"""
+ return (self.options_def_wrapper_cublas() + self.sizes_def() +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.scalar_buffers_first()])) +
+ self.scalar_def_plain("alpha", flavour) +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.buffers_first()])) +
+ self.scalar_def_plain("beta", flavour) +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_def_wrapper_cuda(b, flavour) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
+
+ def arguments_type(self, flavour):
+ """Retrieves a combination of all the argument types"""
+ return (self.options_type() + self.sizes_type() +
+ list(chain(*[self.buffer_type(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_type("alpha", flavour) +
+ list(chain(*[self.buffer_type(b) for b in self.buffers_first()])) +
+ self.scalar_type("beta", flavour) +
+ list(chain(*[self.buffer_type(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_type(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_type(s, flavour) for s in self.other_scalars()])) +
+ self.batch_count_type())
+
+ def arguments_doc(self):
+ """Retrieves a combination of all the argument types"""
+ return (self.options_doc() + self.sizes_doc() +
+ list(chain(*[self.buffer_doc(b) for b in self.scalar_buffers_first()])) +
+ self.scalar_doc("alpha") +
+ list(chain(*[self.buffer_doc(b) for b in self.buffers_first()])) +
+ self.scalar_doc("beta") +
+ list(chain(*[self.buffer_doc(b) for b in self.buffers_second()])) +
+ list(chain(*[self.buffer_doc(b) for b in self.scalar_buffers_second()])) +
+ list(chain(*[self.scalar_doc(s) for s in self.other_scalars()])) +
+ self.batch_count_doc())
+
+ def arguments_python(self):
+ """Arguments for the Python wrapper pyclblast"""
+ result = list()
+ result.extend(self.sizes)
+ buffers = self.inputs + self.outputs
+ result.extend(buffers[:])
+ for buf in buffers:
+ if buf in self.buffers_matrix():
+ result.append(buf + "_ld")
+ for buf in buffers:
+ if buf in self.buffers_vector():
+ result.append(buf + "_inc = 1")
+ for scalar in self.scalars:
+ default = "1.0" if scalar == "alpha" else "0.0"
+ result.append(scalar + " = " + default)
+ for option in self.options:
+ if option == "a_transpose":
+ result.append("a_transp = False")
+ if option == "b_transpose":
+ result.append("b_transp = False")
+ if option == "ab_transpose":
+ result.append("ab_transp = False")
+ if option == "side":
+ result.append("right_side = False")
+ if option == "triangle":
+ result.append("lower_triangle = False")
+ if option == "diagonal":
+ result.append("unit_diagonal = False")
+ for buf in buffers:
+ result.append(buf + "_offset = 0")
+ return result
+
    def requirements_doc(self):
        """Retrieves a list of routine requirements for documentation"""
        # Plain accessor: the requirements are stored pre-formatted on the routine.
        return self.requirements
+
+ def routine_header_cpp(self, spaces, default_event, cuda=False, implementation=False):
+ """Retrieves the C++ templated definition for a routine"""
+ indent = " " * (spaces + self.length())
+ arguments = self.arguments_def(self.template)
+ mem_type = "cl_mem"
+ if cuda:
+ arguments = [a.replace(mem_type, "CUdeviceptr") for a in arguments]
+ mem_type = "CUdeviceptr"
+ result = "template <" + self.template.name + ">\n"
+ result += "StatusCode " + self.capitalized_name() + "("
+ result += (",\n" + indent).join([a for a in arguments])
+ result += ",\n" + indent
+ if cuda:
+ result += "const CUcontext context, const CUdevice device"
+ else:
+ result += "cl_command_queue* queue, cl_event* event" + default_event
+ if self.temp_buffer:
+ result += ",\n" + indent + mem_type + " temp_buffer"
+ if not implementation:
+ result += " = 0" if cuda else " = nullptr"
+ result += ")"
+ return result
+
+ def routine_header_type_cpp(self, spaces, cuda=False):
+ """As above, but now without variable names"""
+ indent = " " * (spaces + self.length())
+ arguments = self.arguments_type(self.template)
+ if cuda:
+ arguments = [a.replace("cl_mem", "CUdeviceptr") for a in arguments]
+ result = "template <" + self.template.name + ">\n"
+ result += "StatusCode " + self.capitalized_name() + "("
+ result += (",\n" + indent).join([a for a in arguments])
+ result += ",\n" + indent
+ if cuda:
+ result += "const CUcontext, const CUdevice"
+ else:
+ result += "cl_command_queue*, cl_event*"
+ result += ")"
+ return result
+
+ def routine_header_c(self, flavour, spaces, extra_qualifier):
+ """As above, but now for C"""
+ indent = " " * (spaces + self.length())
+ result = "CLBlastStatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.plain_name() + "("
+ result += (",\n" + indent).join([a for a in self.arguments_def_c(flavour)])
+ result += ",\n" + indent + "cl_command_queue* queue, cl_event* event)"
+ return result
+
+ def routine_header_netlib(self, flavour, spaces, extra_qualifier):
+ """As above, but now for the original Netlib CBLAS API"""
+ return_type = "void"
+ for output in self.outputs:
+ if output in self.index_buffers():
+ return_type = "int"
+ break
+ if output in self.scalar_buffers_first() and self.name not in self.routines_scalar_no_return():
+ return_type = flavour.buffer_type.replace("2", "")
+ break
+ indent = " " * (spaces + len(return_type) + self.length())
+ routine_name = self.name
+ if self.name in self.routines_scalar_no_return():
+ routine_name += "_sub"
+ indent += " "
+ if self.batched != 0:
+ routine_name += "batched"
+ result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "("
+ result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")"
+ return result
+
+ def routine_header_wrapper_clblas(self, flavour, def_only, spaces):
+ """As above, but now for the clBLAS wrapper"""
+ template = "<" + flavour.template + ">" if self.no_scalars() and not def_only else ""
+ indent = " " * (spaces + self.length() + len(template))
+ result = ""
+ if self.no_scalars():
+ result += "template <"
+ if def_only:
+ result += flavour.name
+ result += ">\n"
+ result += "clblasStatus clblasX" + self.name + template + "("
+ result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_clblas(flavour)])
+ result += ",\n" + indent + "cl_uint num_queues, cl_command_queue *queues"
+ result += ",\n" + indent + "cl_uint num_wait_events, const cl_event *wait_events, cl_event *events)"
+ return result
+
+ def routine_header_wrapper_cblas(self, flavour, spaces):
+ """As above, but now for the CBLAS wrapper"""
+ indent = " " * (spaces + self.length())
+ result = "void cblasX" + self.name + "("
+ result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cblas(flavour)]) + ")"
+ return result
+
+ def routine_header_wrapper_cublas(self, flavour, def_only, spaces):
+ """As above, but now for the cuBLAS wrapper"""
+ template = "<" + flavour.template + ">" if self.no_scalars() and not def_only else ""
+ indent = " " * (spaces + self.length() + len(template))
+ result = ""
+ if self.no_scalars():
+ result += "template <"
+ if def_only:
+ result += flavour.name
+ result += ">\n"
+ result += "cublasStatus_t cublasX" + self.name + template + "(cublasHandle_t handle, "
+ result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cublas(flavour)]) + ")"
+ return result