summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2019-01-22 21:13:41 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2019-01-22 21:13:41 +0100
commit3937efdcda30843eb6b9f4122482800593fa7822 (patch)
tree47407c9b8b64dcd39f3f252ef2dcb5a024597ef9 /scripts
parent9a9c24e811ddefb6e9d462288916ff64dbf47d63 (diff)
Added experimental support for half-precision in pyclblast
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/generator/generator.py2
-rw-r--r--scripts/generator/generator/pyclblast.py7
2 files changed, 6 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 68e3f01a..1e274abd 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -49,7 +49,7 @@ FILES = [
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
-HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 290]
+HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 291]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py
index ab719f5e..47eb2eb4 100644
--- a/scripts/generator/generator/pyclblast.py
+++ b/scripts/generator/generator/pyclblast.py
@@ -18,6 +18,7 @@ def to_np_dtype(flavour):
"D": "float64",
"C": "complex64",
"Z": "complex128",
+ "H": "float16",
}[flavour.precision_name]
@@ -31,6 +32,8 @@ def scalar_cython_conversion(scalar, flavour):
return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)"
if scalar_type in ["cl_double2", "double2"]:
return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)"
+ if scalar_type in ["cl_half", "half"]:
+ return "<cl_half>" + scalar
raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type))
@@ -48,7 +51,7 @@ def generate_pyx(routine):
result += "cdef extern from \"clblast_c.h\":" + NL
np_dtypes = []
for flavour in routine.flavours:
- if flavour.precision_name in ["S", "D", "C", "Z"]:
+ if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "("
result += ", ".join(routine.arguments_def_c(flavour)) + ","
result += "cl_command_queue* queue, cl_event* event)" + NL
@@ -103,7 +106,7 @@ def generate_pyx(routine):
result += indent + "cdef CLBlastStatusCode err" + NL
if_prefix = ""
for flavour in routine.flavours:
- if flavour.precision_name in ["S", "D", "C", "Z"]:
+ if flavour.precision_name in ["S", "D", "C", "Z", "H"]:
np_dtype = to_np_dtype(flavour)
argument_names = [x.
replace("layout", "CLBlastLayoutRowMajor").