From 3937efdcda30843eb6b9f4122482800593fa7822 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Tue, 22 Jan 2019 21:13:41 +0100 Subject: Added experimental support for half-precision in pyclblast --- scripts/generator/generator.py | 2 +- scripts/generator/generator/pyclblast.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'scripts') diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 68e3f01a..1e274abd 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -49,7 +49,7 @@ FILES = [ "/src/clblast_cuda.cpp", "/src/pyclblast/src/pyclblast.pyx" ] -HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 290] +HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 291] FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 232 diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index ab719f5e..47eb2eb4 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -18,6 +18,7 @@ def to_np_dtype(flavour): "D": "float64", "C": "complex64", "Z": "complex128", + "H": "float16", }[flavour.precision_name] @@ -31,6 +32,8 @@ def scalar_cython_conversion(scalar, flavour): return "cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)" if scalar_type in ["cl_double2", "double2"]: return "cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)" + if scalar_type in ["cl_half", "half"]: + return "" + scalar raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type)) @@ -48,7 +51,7 @@ def generate_pyx(routine): result += "cdef extern from \"clblast_c.h\":" + NL np_dtypes = [] for flavour in routine.flavours: - if flavour.precision_name in ["S", "D", "C", "Z"]: + if flavour.precision_name in ["S", "D", "C", "Z", "H"]: result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "(" result += ", ".join(routine.arguments_def_c(flavour)) + "," result += "cl_command_queue* queue, cl_event* event)" + NL @@ -103,7 +106,7 @@ def generate_pyx(routine): result += indent + "cdef CLBlastStatusCode err" + NL if_prefix = "" for flavour in routine.flavours: - if flavour.precision_name in ["S", "D", "C", "Z"]: + if flavour.precision_name in ["S", "D", "C", "Z", "H"]: np_dtype = to_np_dtype(flavour) argument_names = [x. replace("layout", "CLBlastLayoutRowMajor"). -- cgit v1.2.3