summary | refs | log | tree | commit | diff
path: root/scripts/generator/generator/cpp.py
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2016-10-25 14:28:52 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2016-10-25 14:28:52 +0200
commit: f96fd372bc3087938572ebc55bd1d8e1b7e6f18a (patch)
tree: 5a3a20fa4fdefd942bfd4dbd6713f687cfe67d5b /scripts/generator/generator/cpp.py
parent: 3b65eace0a4a48568353da3a86ac46d9ff1f1ffc (diff)
Added initial version of a Netlib CBLAS implementation. TODO: Set correct buffer sizes
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- scripts/generator/generator/cpp.py 64
1 file changed, 39 insertions, 25 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 61730fdb..23a2207c 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -99,7 +99,8 @@ def clblast_blas_h(routine):
"""The Netlib CBLAS API header (.h)"""
result = NL + "// " + routine.description + ": " + routine.short_names() + NL
for flavour in routine.flavours:
- result += routine.routine_header_netlib(flavour, 24, " PUBLIC_API") + ";" + NL
+ if flavour.precision_name in ["S", "D", "C", "Z"]:
+ result += routine.routine_header_netlib(flavour, 24, " PUBLIC_API") + ";" + NL
return result
@@ -107,31 +108,44 @@ def clblast_blas_cc(routine):
"""The Netlib CBLAS API implementation (.cpp)"""
result = NL + "// " + routine.name.upper() + NL
for flavour in routine.flavours:
- template = "<" + flavour.template + ">" if routine.no_scalars() else ""
- indent = " " * (26 + routine.length() + len(template))
- result += routine.routine_header_netlib(flavour, 13, "") + " {" + NL
-
- # Initialize OpenCL
- result += " auto platform = Platform(size_t{0});" + NL
- result += " auto device = Device(platform, size_t{0});" + NL
- result += " auto context = Context(device);" + NL
- result += " auto queue = Queue(context, device);" + NL
-
- # Copy data structures to the device
- for name in routine.inputs + routine.outputs:
- result += " " + routine.create_buffer(name, flavour.template, "0") + NL
- for name in routine.inputs + routine.outputs:
- result += " " + routine.write_buffer(name, "0") + NL
-
- # The function call
- result += " auto status = clblast::" + routine.name.capitalize() + template + "("
- result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)])
- result += "," + NL + indent + "queue, event);" + NL
- # Copy back and clean-up
- for name in routine.outputs:
- result += " " + routine.read_buffer(name, "0") + NL
- result += " return;" + NL + "}" + NL
+ # There is a version available in CBLAS
+ if flavour.precision_name in ["S", "D", "C", "Z"]:
+ template = "<" + flavour.template + ">" if routine.no_scalars() else ""
+ indent = " " * (12 + routine.length() + len(template))
+ result += routine.routine_header_netlib(flavour, 13, "") + " {" + NL
+
+ # Initialize OpenCL
+ result += " auto device = get_device();" + NL
+ result += " auto context = Context(device);" + NL
+ result += " auto queue = Queue(context, device);" + NL
+
+ # Set alpha and beta
+ result += "".join(" " + s + NL for s in routine.scalar_create_cpp(flavour))
+
+ # Copy data structures to the device
+ for i, name in enumerate(routine.inputs + routine.outputs):
+ result += " " + routine.set_size(name, routine.buffer_sizes[i]) + NL
+ result += " " + routine.create_buffer(name, flavour.buffer_type) + NL
+ for name in routine.inputs + routine.outputs:
+ prefix = "" if name in routine.outputs else "const "
+ result += " " + routine.write_buffer(name, prefix + flavour.buffer_type) + NL
+
+ # The function call
+ result += " auto queue_cl = queue();" + NL
+ result += " auto s = " + routine.name.capitalize() + template + "("
+ result += ("," + NL + indent).join([a for a in routine.arguments_netlib(flavour, indent)])
+ result += "," + NL + indent + "&queue_cl);" + NL
+
+ # Error handling
+ result += " if (s != StatusCode::kSuccess) {" + NL
+ result += " throw std::runtime_error(\"CLBlast returned with error code \" + ToString(s));" + NL
+ result += " }" + NL
+
+ # Copy back and clean-up
+ for name in routine.outputs:
+ result += " " + routine.read_buffer(name, flavour.buffer_type) + NL
+ result += "}" + NL
return result