summaryrefslogtreecommitdiff
path: root/scripts/generator
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-02-18 16:33:20 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-02-18 16:33:20 +0100
commite1bfb4082716ef9619a13e9985aca9ef28cf4cbf (patch)
tree94178e2501add7fd8a3ad683ef3ba2ed1e7cafd8 /scripts/generator
parenteb85f6b514b285d7dde1ac02b97b7581a46ff21d (diff)
Added GEMM to the Python wrapper
Diffstat (limited to 'scripts/generator')
-rw-r--r--scripts/generator/generator/convert.py13
-rw-r--r--scripts/generator/generator/pyclblast.py37
-rw-r--r--scripts/generator/generator/routine.py17
3 files changed, 49 insertions, 18 deletions
diff --git a/scripts/generator/generator/convert.py b/scripts/generator/generator/convert.py
index 44eb69d6..07f45669 100644
--- a/scripts/generator/generator/convert.py
+++ b/scripts/generator/generator/convert.py
@@ -80,16 +80,3 @@ def option_to_documentation(x):
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
}[x]
-
-
-def option_to_clblastdefault(x):
- """Translates an option name to a CLBlast C default type"""
- return {
- 'layout': "CLBlastLayoutColMajor",
- 'a_transpose': "CLBlastTransposeNo",
- 'b_transpose': "CLBlastTransposeNo",
- 'ab_transpose': "CLBlastTransposeNo",
- 'side': "CLBlastSideLeft",
- 'triangle': "CLBlastTriangleUpper",
- 'diagonal': "CLBlastDiagonalNonUnit",
- }[x]
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
index 089a410a..ffeaab8d 100644
--- a/scripts/generator/generator/pyclblast.py
+++ b/scripts/generator/generator/pyclblast.py
@@ -21,9 +21,21 @@ def to_np_dtype(flavour):
raise RuntimeError("Could not convert flavour '%s' to numpy" % flavour.precision_name)
+def scalar_cython_conversion(scalar, flavour):
+ if flavour.precision_name == "S":
+ return "<cl_float>" + scalar
+ if flavour.precision_name == "D":
+ return "<cl_double>" + scalar
+ if flavour.precision_name == "C":
+ return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ if flavour.precision_name == "Z":
+ return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ raise RuntimeError("Could not convert flavour '%s'" % flavour.precision_name)
+
+
def generate_pyx(routine):
result = ""
- if routine.implemented and routine.plain_name() == "swap": # TODO: Generalize
+ if routine.implemented and routine.plain_name() in ["swap", "gemm"]: # TODO: Generalize
result += SEPARATOR + NL
result += "# " + routine.description + ": " + routine.short_names() + NL
@@ -59,16 +71,35 @@ def generate_pyx(routine):
result += " cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr" + NL
result += " cdef cl_event event = NULL" + NL
- result += "" + NL
+ for option in routine.options:
+ if option == "a_transpose":
+ result += " a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo" + NL
+ if option == "b_transpose":
+ result += " b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo" + NL
+ if option == "ab_transpose":
+ result += " ab_transpose = CLBlastTransposeYes if ab_transp else CLBlastTransposeNo" + NL
+ if option == "side":
+ result += " side = CLBlastSideRight if right_side else CLBlastSideLeft" + NL
+ if option == "triangle":
+ result += " triangle = CLBlastTriangleLower if lower_triangle else CLBlastTriangleUpper" + NL
+ if option == "diagonal":
+ result += " diagonal = CLBlastDiagonalUnit if unit_diagonal else CLBlastDiagonalNonUnit" + NL
+
+ result += "" + NL
result += " cdef CLBlastStatusCode err" + NL
if_prefix = ""
for flavour in routine.flavours:
if flavour.precision_name in ["S", "D", "C", "Z"]:
np_dtype = to_np_dtype(flavour)
+ argument_names = [x.
+ replace("layout", "CLBlastLayoutRowMajor").
+ replace("alpha", scalar_cython_conversion("alpha", flavour)).
+ replace("beta", scalar_cython_conversion("beta", flavour))
+ for x in routine.arguments()]
result += " " + if_prefix + "if dtype == np.dtype(\"" + np_dtype + "\"):" + NL
result += " err = CLBlast" + flavour.name + routine.plain_name()
- result += "(" + ", ".join(routine.arguments()) + ", &command_queue, &event)" + NL
+ result += "(" + ", ".join(argument_names) + ", &command_queue, &event)" + NL
if_prefix = "el"
result += " else:" + NL
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index d0b0a6d7..c52f49ca 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -827,9 +827,22 @@ class Routine:
for buf in buffers:
if buf in self.buffers_vector():
result.append(buf + "_inc = 1")
+ for scalar in self.scalars:
+ default = "1.0" if scalar == "alpha" else "0.0"
+ result.append(scalar + " = " + default)
for option in self.options:
- default = convert.option_to_clblastdefault(option)
- result.append(option + " = " + default)
+ if option == "a_transpose":
+ result.append("a_transp = False")
+ if option == "b_transpose":
+ result.append("b_transp = False")
+ if option == "ab_transpose":
+ result.append("ab_transp = False")
+ if option == "side":
+ result.append("right_side = False")
+ if option == "triangle":
+ result.append("lower_triangle = False")
+ if option == "diagonal":
+ result.append("unit_diagonal = False")
for buf in buffers:
result.append(buf + "_offset = 0")
return result