diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/generator/generator/cpp.py | 1 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 12 |
2 files changed, 11 insertions, 2 deletions
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py index 7c695dc8..79d6b2a1 100644 --- a/scripts/generator/generator/cpp.py +++ b/scripts/generator/generator/cpp.py @@ -319,6 +319,7 @@ def wrapper_cublas(routine): # Calls the cuBLAS routine result += " cublasHandle_t handle;" + NL + result += " if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) { return CUBLAS_STATUS_NOT_INITIALIZED; }" + NL result += " auto status = cublas" + flavour.name_cublas() + routine.name + "(handle, " result += ("," + NL + indent).join([a for a in arguments]) + ");" + NL result += " cublasDestroy(handle);" + NL diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index b1db484f..a7abfde5 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -347,7 +347,12 @@ class Routine: """As above but for cuBLAS the wrapper""" prefix = "const " if name in self.inputs else "" if name in self.inputs or name in self.outputs: - if flavour.precision_name in ["C", "Z"]: + if name in self.index_buffers(): + a = ["reinterpret_cast<int*>(&" + name + "_buffer[" + name + "_offset])"] + elif name in self.outputs and flavour.name in ["Sc", "Dz"]: + dtype = "float" if flavour.name == "Sc" else "double" + a = ["reinterpret_cast<" + dtype + "*>(&" + name + "_buffer[" + name + "_offset])"] + elif flavour.precision_name in ["C", "Z"]: cuda_complex = "cuDoubleComplex" if flavour.precision_name == "Z" else "cuComplex" a = ["reinterpret_cast<" + prefix + cuda_complex + "*>" + "(&" + name + "_buffer[" + name + "_offset])"] @@ -358,7 +363,10 @@ class Routine: c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"] elif name in ["a", "b", "c"]: c = [name + "_" + self.postfix(name)] - return [", ".join(a + c)] + result = [", ".join(a + c)] + if self.name == "trmm" and name == "a": + result *= 2 + return result return [] def buffer_type(self, name): |