diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2016-10-25 14:28:52 +0200 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2016-10-25 14:28:52 +0200 |
commit | f96fd372bc3087938572ebc55bd1d8e1b7e6f18a (patch) | |
tree | 5a3a20fa4fdefd942bfd4dbd6713f687cfe67d5b /scripts/generator/generator/cpp.py | |
parent | 3b65eace0a4a48568353da3a86ac46d9ff1f1ffc (diff) |
Added initial version of a Netlib CBLAS implementation. TODO: Set correct buffer sizes
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- | scripts/generator/generator/cpp.py | 64 |
1 files changed, 39 insertions, 25 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 61730fdb..23a2207c 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -99,7 +99,8 @@ def clblast_blas_h(routine): """The Netlib CBLAS API header (.h)""" result = NL + "// " + routine.description + ": " + routine.short_names() + NL for flavour in routine.flavours: - result += routine.routine_header_netlib(flavour, 24, " PUBLIC_API") + ";" + NL + if flavour.precision_name in ["S", "D", "C", "Z"]: + result += routine.routine_header_netlib(flavour, 24, " PUBLIC_API") + ";" + NL return result @@ -107,31 +108,44 @@ def clblast_blas_cc(routine): """The Netlib CBLAS API implementation (.cpp)""" result = NL + "// " + routine.name.upper() + NL for flavour in routine.flavours: - template = "<" + flavour.template + ">" if routine.no_scalars() else "" - indent = " " * (26 + routine.length() + len(template)) - result += routine.routine_header_netlib(flavour, 13, "") + " {" + NL - - # Initialize OpenCL - result += " auto platform = Platform(size_t{0});" + NL - result += " auto device = Device(platform, size_t{0});" + NL - result += " auto context = Context(device);" + NL - result += " auto queue = Queue(context, device);" + NL - - # Copy data structures to the device - for name in routine.inputs + routine.outputs: - result += " " + routine.create_buffer(name, flavour.template, "0") + NL - for name in routine.inputs + routine.outputs: - result += " " + routine.write_buffer(name, "0") + NL - - # The function call - result += " auto status = clblast::" + routine.name.capitalize() + template + "(" - result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)]) - result += "," + NL + indent + "queue, event);" + NL - # Copy back and clean-up - for name in routine.outputs: - result += " " + routine.read_buffer(name, "0") + NL - result += " return;" + NL + "}" + NL + # There is a version available in CBLAS + if flavour.precision_name in ["S", "D", "C", "Z"]: + template = "<" + flavour.template + ">" if routine.no_scalars() else "" + indent = " " * (12 + routine.length() + len(template)) + result += routine.routine_header_netlib(flavour, 13, "") + " {" + NL + + # Initialize OpenCL + result += " auto device = get_device();" + NL + result += " auto context = Context(device);" + NL + result += " auto queue = Queue(context, device);" + NL + + # Set alpha and beta + result += "".join(" " + s + NL for s in routine.scalar_create_cpp(flavour)) + + # Copy data structures to the device + for i, name in enumerate(routine.inputs + routine.outputs): + result += " " + routine.set_size(name, routine.buffer_sizes[i]) + NL + result += " " + routine.create_buffer(name, flavour.buffer_type) + NL + for name in routine.inputs + routine.outputs: + prefix = "" if name in routine.outputs else "const " + result += " " + routine.write_buffer(name, prefix + flavour.buffer_type) + NL + + # The function call + result += " auto queue_cl = queue();" + NL + result += " auto s = " + routine.name.capitalize() + template + "(" + result += ("," + NL + indent).join([a for a in routine.arguments_netlib(flavour, indent)]) + result += "," + NL + indent + "&queue_cl);" + NL + + # Error handling + result += " if (s != StatusCode::kSuccess) {" + NL + result += " throw std::runtime_error(\"CLBlast returned with error code \" + ToString(s));" + NL + result += " }" + NL + + # Copy back and clean-up + for name in routine.outputs: + result += " " + routine.read_buffer(name, flavour.buffer_type) + NL + result += "}" + NL return result |