summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/cpp.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r--scripts/generator/generator/cpp.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index 03da7985..49240095 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -290,6 +290,52 @@ def wrapper_cblas(routine):
return result
def wrapper_cublas(routine):
    """Generate the C++ wrappers that forward to the reference cuBLAS routines.

    The generated wrappers are used for performance/correctness testing.

    Args:
        routine: Routine description. This function reads its `has_tests`,
            `template`, `flavours` and `name` attributes and calls its
            `short_names_tested()`, `no_scalars()`, `length()`,
            `routine_header_wrapper_cublas()` and `arguments_wrapper_cublas()`
            methods.

    Returns:
        The generated C++ source as a single string. The string is empty when
        `routine.has_tests` is false.
    """
    if not routine.has_tests:
        return ""

    result = NL + "// Forwards the cuBLAS calls for %s" % routine.short_names_tested() + NL

    # Routines without scalars additionally get a header terminated by ';'
    # (i.e. a declaration rather than a definition) based on the template.
    if routine.no_scalars():
        result += routine.routine_header_wrapper_cublas(routine.template, True, 23) + ";" + NL

    for flavour in routine.flavours:
        result += routine.routine_header_wrapper_cublas(flavour, False, 23) + " {" + NL

        # There is a version available in cuBLAS
        if flavour.precision_name in ["S", "D", "C", "Z"]:
            # Continuation lines of the argument list are aligned using this
            # indent; the routine-call line below is kept unchanged so the
            # alignment stays valid.
            indent = " " * (24 + routine.length())
            arguments = routine.arguments_wrapper_cublas(flavour)
            result += " cublasHandle_t handle;" + NL
            # The handle must be initialised with cublasCreate before it is
            # passed to any cuBLAS routine or to cublasDestroy; bail out with
            # the creation status if that fails.
            result += " const auto create_status = cublasCreate(&handle);" + NL
            result += " if (create_status != CUBLAS_STATUS_SUCCESS) { return create_status; }" + NL
            result += " auto status = cublas" + flavour.name + routine.name + "(handle, "
            result += ("," + NL + indent).join(arguments) + ");" + NL
            result += " cublasDestroy(handle);" + NL
            result += " return status;"

        # No cuBLAS call is generated for the remaining precisions (e.g.
        # half-precision): the wrapper only reports that it is unsupported.
        else:
            result += " return CUBLAS_STATUS_NOT_SUPPORTED;"

        # Complete
        result += NL + "}" + NL
    return result
+
+
def performance_test(routine, level_string):
"""Generates the body of a performance test for a specific routine"""
result = ""