diff options
-rwxr-xr-x | scripts/generator/generator.py | 2 | ||||
-rw-r--r-- | src/pyclblast/samples/haxpy.py | 5 | ||||
-rw-r--r-- | src/pyclblast/src/pyclblast.pyx | 36 |
3 files changed, 40 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 1e274abd..76c5dc1c 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -49,7 +49,7 @@ FILES = [ "/src/clblast_cuda.cpp", "/src/pyclblast/src/pyclblast.pyx" ] -HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 291] +HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 327] FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 232 diff --git a/src/pyclblast/samples/haxpy.py b/src/pyclblast/samples/haxpy.py index d6c1fef0..3db87a34 100644 --- a/src/pyclblast/samples/haxpy.py +++ b/src/pyclblast/samples/haxpy.py @@ -13,7 +13,8 @@ import pyclblast # Settings for this sample dtype = 'float16' -alpha = np.array(1.0).astype(dtype=dtype).item() +alpha = 1.5 +alpha_fp16 = pyclblast.float32_to_float16(alpha) n = 4 print("# Setting up OpenCL") @@ -31,7 +32,7 @@ clx.set(x) cly.set(y) print("# Example level-1 operation: AXPY") -pyclblast.axpy(queue, n, clx, cly, alpha=alpha) +pyclblast.axpy(queue, n, clx, cly, alpha=alpha_fp16) queue.finish() print("# Result for vector y: %s" % cly.get()) print("# Expected result: %s" % (alpha * x + y)) diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index fc9f437d..14efcf8a 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -11,6 +11,8 @@ # #################################################################################################### +import binascii +import struct import numpy as np import pyopencl as cl from pyopencl.array import Array @@ -288,6 +290,40 @@ def check_matrix(a, name): def check_vector(a, name): check_array(a, 1, name) +#################################################################################################### +# Half-precision utility functions +#################################################################################################### + +def float32_to_float16(float32): + # Taken from https://gamedev.stackexchange.com/a/28756 + F16_EXPONENT_BITS = 0x1F + F16_EXPONENT_SHIFT = 10 + F16_EXPONENT_BIAS = 15 + F16_MANTISSA_BITS = 0x3ff + F16_MANTISSA_SHIFT = (23 - F16_EXPONENT_SHIFT) + F16_MAX_EXPONENT = (F16_EXPONENT_BITS << F16_EXPONENT_SHIFT) + + a = struct.pack('>f', float32) + b = binascii.hexlify(a) + + f32 = int(b, 16) + sign = (f32 >> 16) & 0x8000 + exponent = ((f32 >> 23) & 0xff) - 127 + mantissa = f32 & 0x007fffff + + if exponent == 128: + f16 = sign | F16_MAX_EXPONENT + if mantissa: + f16 |= (mantissa & F16_MANTISSA_BITS) + elif exponent > 15: + f16 = sign | F16_MAX_EXPONENT + elif exponent > -15: + exponent += F16_EXPONENT_BIAS + mantissa >>= F16_MANTISSA_SHIFT + f16 = sign | exponent << F16_EXPONENT_SHIFT | mantissa + else: + f16 = sign + return f16 #################################################################################################### # Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP |