summaryrefslogtreecommitdiff
path: root/scripts/generator
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2019-01-26 11:04:14 +0100
committerGitHub <noreply@github.com>2019-01-26 11:04:14 +0100
commiteff0f9ad1d6dafa0c4519cc145cbf39511bf737d (patch)
treeeff8b433e24ebfa23dd0a2c4207381059373b81c /scripts/generator
parent9a9c24e811ddefb6e9d462288916ff64dbf47d63 (diff)
parente0541c41a17bd500ab3f03bbb9a934d9cc3b0a75 (diff)
Merge pull request #348 from CNugteren/CLBlast-334-pyclblast-half-precision-support
PyCLBlast half precision support
Diffstat (limited to 'scripts/generator')
-rwxr-xr-xscripts/generator/generator.py2
-rw-r--r--scripts/generator/generator/pyclblast.py7
2 files changed, 6 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 68e3f01a..76c5dc1c 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -49,7 +49,7 @@ FILES = [
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
-HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 290]
+HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 327]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
index ab719f5e..47eb2eb4 100644
--- a/scripts/generator/generator/pyclblast.py
+++ b/scripts/generator/generator/pyclblast.py
@@ -18,6 +18,7 @@ def to_np_dtype(flavour):
"D": "float64",
"C": "complex64",
"Z": "complex128",
+ "H": "float16",
}[flavour.precision_name]
@@ -31,6 +32,8 @@ def scalar_cython_conversion(scalar, flavour):
return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)"
if scalar_type in ["cl_double2", "double2"]:
return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ if scalar_type in ["cl_half", "half"]:
+ return "<cl_half>" + scalar
raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type))
@@ -48,7 +51,7 @@ def generate_pyx(routine):
result += "cdef extern from \"clblast_c.h\":" + NL
np_dtypes = []
for flavour in routine.flavours:
- if flavour.precision_name in ["S", "D", "C", "Z"]:
+ if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
result += ", ".join(routine.arguments_def_c(flavour)) + ","
result += "cl_command_queue* queue, cl_event* event)" + NL
@@ -103,7 +106,7 @@ def generate_pyx(routine):
result += indent + "cdef CLBlastStatusCode err" + NL
if_prefix = ""
for flavour in routine.flavours:
- if flavour.precision_name in ["S", "D", "C", "Z"]:
+ if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
np_dtype = to_np_dtype(flavour)
argument_names = [x.
replace("layout", "CLBlastLayoutRowMajor").