diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2018-02-18 16:33:20 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2018-02-18 16:33:20 +0100 |
commit | e1bfb4082716ef9619a13e9985aca9ef28cf4cbf (patch) | |
tree | 94178e2501add7fd8a3ad683ef3ba2ed1e7cafd8 /scripts/generator | |
parent | eb85f6b514b285d7dde1ac02b97b7581a46ff21d (diff) |
Added GEMM to the Python wrapper
Diffstat (limited to 'scripts/generator')
-rw-r--r-- | scripts/generator/generator/convert.py | 13 | ||||
-rw-r--r-- | scripts/generator/generator/pyclblast.py | 37 | ||||
-rw-r--r-- | scripts/generator/generator/routine.py | 17 |
3 files changed, 49 insertions, 18 deletions
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py index 44eb69d6..07f45669 100644 --- a/scripts/generator/generator/convert.py +++ b/scripts/generator/generator/convert.py @@ -80,16 +80,3 @@ def option_to_documentation(x): 'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).", 'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.", }[x] - - -def option_to_clblastdefault(x): - """Translates an option name to a CLBlast C default type""" - return { - 'layout': "CLBlastLayoutColMajor", - 'a_transpose': "CLBlastTransposeNo", - 'b_transpose': "CLBlastTransposeNo", - 'ab_transpose': "CLBlastTransposeNo", - 'side': "CLBlastSideLeft", - 'triangle': "CLBlastTriangleUpper", - 'diagonal': "CLBlastDiagonalNonUnit", - }[x] diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index 089a410a..ffeaab8d 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -21,9 +21,21 @@ def to_np_dtype(flavour): raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name) +def scalar_cython_conversion(scalar, flavour): + if flavour.precision_name == "S": + return "<cl_float>" + scalar + if flavour.precision_name == "D": + return "<cl_double>" + scalar + if flavour.precision_name == "C": + return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)" + if flavour.precision_name == "Z": + return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)" + raise RuntimeError("Could not convert flavour '%s'" % flavour.precision_name) + + def generate_pyx(routine): result = "" - if routine.implemented and routine.plain_name() == "swap": # TODO: Generalize + if routine.implemented and routine.plain_name() in ["swap", "gemm"]: # TODO: Generalize result += SEPARATOR + NL result += "# " + routine.description + ": " + routine.short_names() + NL @@ -59,16 +71,35 @@ def generate_pyx(routine): result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL result += " cdef cl_event event = NULL" + NL - result += "" + NL + for option in routine.options: + if option == "a_transpose": + result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL + if option == "b_transpose": + result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL + if option == "ab_transpose": + result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL + if option == "side": + result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL + if option == "triangle": + result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL + if option == "diagonal": + result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL + + result += "" + NL result += " cdef CLBlastStatusCode err" + NL if_prefix = "" for flavour in routine.flavours: if flavour.precision_name in ["S", "D", "C", "Z"]: np_dtype = to_np_dtype(flavour) + argument_names = [x. + replace("layout", "CLBlastLayoutRowMajor"). + replace("alpha", scalar_cython_conversion("alpha", flavour)). + replace("beta", scalar_cython_conversion("beta", flavour)) + for x in routine.arguments()] result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL result += " err = CLBlast" + flavour.name + routine.plain_name() - result += "(" + ", ".join(routine.arguments()) + ", &command_queue, &event)" + NL + result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL if_prefix = "el" result += " else:" + NL diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index d0b0a6d7..c52f49ca 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -827,9 +827,22 @@ class Routine: for buf in buffers: if buf in self.buffers_vector(): result.append(buf + "_inc = 1") + for scalar in self.scalars: + default = "1.0" if scalar == "alpha" else "0.0" + result.append(scalar + " = " + default) for option in self.options: - default = convert.option_to_clblastdefault(option) - result.append(option + " = " + default) + if option == "a_transpose": + result.append("a_transp = False") + if option == "b_transpose": + result.append("b_transp = False") + if option == "ab_transpose": + result.append("ab_transp = False") + if option == "side": + result.append("right_side = False") + if option == "triangle": + result.append("lower_triangle = False") + if option == "diagonal": + result.append("unit_diagonal = False") for buf in buffers: result.append(buf + "_offset = 0") return result |