diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/generator/generator/cpp.py     | 11 ++++-------
-rw-r--r-- | scripts/generator/generator/routine.py |  2 +-
2 files changed, 5 insertions, 8 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 79d6b2a1..17e418e3 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -318,11 +318,9 @@ def wrapper_cublas(routine): result += " " + scalar + "_cuda.y = " + scalar + ".imag();" + NL # Calls the cuBLAS routine - result += " cublasHandle_t handle;" + NL - result += " if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) { return CUBLAS_STATUS_NOT_INITIALIZED; }" + NL result += " auto status = cublas" + flavour.name_cublas() + routine.name + "(handle, " result += ("," + NL + indent).join([a for a in arguments]) + ");" + NL - result += " cublasDestroy(handle);" + NL + result += " cudaDeviceSynchronize();" + NL result += " return status;" # There is no cuBLAS available, forward the call to one of the available functions @@ -335,11 +333,10 @@ def wrapper_cublas(routine): # result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer, queues[0]);" + NL # # Call the float routine - # result += " cublasHandle_t handle;" + NL - # result += " auto status = cublasX" + routine.name + "(handle," + # result += " return cublasX" + routine.name + "(handle," # result += ("," + NL + indent).join([a for a in routine.arguments_half()]) + ");" + NL - # result += " cublasDestroy(handle);" + NL - # result += " return status;" + NL + # result += " cudaDeviceSynchronize();" + NL + # result += " return status;" # # Convert back to half # for buf in routine.outputs: diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index a7abfde5..1c534611 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -884,6 +884,6 @@ class Routine: if def_only: result += flavour.name result += ">\n" - result += "cublasStatus_t cublasX" + self.name + template + "(" + result += "cublasStatus_t cublasX" + self.name + template + "(cublasHandle_t handle, " result += (",\n" + 
indent).join([a for a in self.arguments_def_wrapper_cublas(flavour)]) + ")" return result |