diff options
Diffstat (limited to 'scripts/generator/generator/cpp.py')
-rw-r--r-- | scripts/generator/generator/cpp.py | 46 |
1 files changed, 46 insertions, 0 deletions
def wrapper_cublas(routine):
    """Generate the C++ wrappers forwarding to the reference cuBLAS routines.

    The generated wrappers are used for performance/correctness testing. For
    every flavour of the routine a wrapper function is emitted:

    * For the S/D/C/Z precisions the wrapper creates a cuBLAS handle, calls the
      corresponding ``cublas<flavour><routine>`` function, destroys the handle
      again and returns the status of the call.
    * For any other precision (half) the wrapper only returns
      ``CUBLAS_STATUS_NOT_SUPPORTED``.

    Args:
        routine: the routine description; this function reads ``has_tests``,
            ``flavours``, ``template`` and ``name`` and calls
            ``short_names_tested()``, ``no_scalars()``, ``length()``,
            ``routine_header_wrapper_cublas()`` and ``arguments_wrapper_cublas()``.

    Returns:
        The generated C++ source as a string, or an empty string when the
        routine has no tests.
    """
    if not routine.has_tests:
        return ""

    result = NL + "// Forwards the cuBLAS calls for %s" % routine.short_names_tested() + NL

    # Routines without scalars additionally get a forward declaration of the template
    if routine.no_scalars():
        result += routine.routine_header_wrapper_cublas(routine.template, True, 23) + ";" + NL

    for flavour in routine.flavours:
        result += routine.routine_header_wrapper_cublas(flavour, False, 23) + " {" + NL

        # There is a version available in cuBLAS
        if flavour.precision_name in ["S", "D", "C", "Z"]:
            # Continuation lines of the argument list are aligned using this indent
            indent = " " * (24 + routine.length())
            arguments = routine.arguments_wrapper_cublas(flavour)
            result += " cublasHandle_t handle;" + NL
            # The handle has to be initialised before it is used or destroyed; without
            # this the generated code passes an uninitialised handle to cuBLAS.
            result += " auto create_status = cublasCreate(&handle);" + NL
            result += " if (create_status != CUBLAS_STATUS_SUCCESS) { return create_status; }" + NL
            result += " auto status = cublas" + flavour.name + routine.name + "(handle, "
            result += ("," + NL + indent).join(arguments) + ");" + NL
            result += " cublasDestroy(handle);" + NL
            result += " return status;"

        # There is no cuBLAS version available for this precision (half-precision).
        # NOTE: forwarding through the single-precision routine (converting the buffers
        # to float and back) is not implemented here.
        else:
            result += " return CUBLAS_STATUS_NOT_SUPPORTED;"

        # Complete
        result += NL + "}" + NL
    return result