diff options
author | Cedric Nugteren <web@cedricnugteren.nl> | 2019-01-22 21:13:41 +0100 |
---|---|---|
committer | Cedric Nugteren <web@cedricnugteren.nl> | 2019-01-22 21:13:41 +0100 |
commit | 3937efdcda30843eb6b9f4122482800593fa7822 (patch) | |
tree | 47407c9b8b64dcd39f3f252ef2dcb5a024597ef9 /scripts/generator/generator | |
parent | 9a9c24e811ddefb6e9d462288916ff64dbf47d63 (diff) |
Added experimental support for half-precision in pyclblast
Diffstat (limited to 'scripts/generator/generator')
-rw-r--r-- | scripts/generator/generator/pyclblast.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/scripts/generator/generator/pyclblast.py b/scripts/generator/generator/pyclblast.py index ab719f5e..47eb2eb4 100644 --- a/scripts/generator/generator/pyclblast.py +++ b/scripts/generator/generator/pyclblast.py @@ -18,6 +18,7 @@ def to_np_dtype(flavour): "D": "float64", "C": "complex64", "Z": "complex128", + "H": "float16", }[flavour.precision_name] @@ -31,6 +32,8 @@ def scalar_cython_conversion(scalar, flavour): return "<cl_float2>cl_float2(x=" + scalar + ".real,y=" + scalar + ".imag)" if scalar_type in ["cl_double2", "double2"]: return "<cl_double2>cl_double2(x=" + scalar + ".real,y=" + scalar + ".imag)" + if scalar_type in ["cl_half", "half"]: + return "<cl_half>" + scalar raise RuntimeError("Could not convert flavour '%s:%s'" % (flavour.precision_name, scalar_type)) @@ -48,7 +51,7 @@ def generate_pyx(routine): result += "cdef extern from \"clblast_c.h\":" + NL np_dtypes = [] for flavour in routine.flavours: - if flavour.precision_name in ["S", "D", "C", "Z"]: + if flavour.precision_name in ["S", "D", "C", "Z", "H"]: result += indent + "CLBlastStatusCode CLBlast" + flavour.name + routine.plain_name() + "(" result += ", ".join(routine.arguments_def_c(flavour)) + "," result += "cl_command_queue* queue, cl_event* event)" + NL @@ -103,7 +106,7 @@ def generate_pyx(routine): result += indent + "cdef CLBlastStatusCode err" + NL if_prefix = "" for flavour in routine.flavours: - if flavour.precision_name in ["S", "D", "C", "Z"]: + if flavour.precision_name in ["S", "D", "C", "Z", "H"]: np_dtype = to_np_dtype(flavour) argument_names = [x. replace("layout", "CLBlastLayoutRowMajor"). |