diff options
Diffstat (limited to 'src/pyclblast')
-rw-r--r-- | src/pyclblast/samples/haxpy.py | 5 | ||||
-rw-r--r-- | src/pyclblast/src/pyclblast.pyx | 36 |
2 files changed, 39 insertions, 2 deletions
####################################################################################################
# Half-precision utility functions
####################################################################################################

def float32_to_float16(float32):
    """Return the IEEE 754 half-precision (binary16) bit pattern of a float.

    The value is first narrowed to single precision and then converted to
    half precision using round-to-nearest, ties-to-even. The result is the
    raw 16-bit encoding as a Python int, e.g. for use as a half-precision
    scalar argument: ``axpy(queue, n, clx, cly, alpha=float32_to_float16(1.5))``.

    Originally based on https://gamedev.stackexchange.com/a/28756, extended
    with rounding, subnormal and NaN handling.

    Args:
        float32: A Python float (or any object ``struct.pack('>f', ...)``
            accepts).

    Returns:
        An int in ``[0, 0xFFFF]`` holding sign (1 bit), exponent (5 bits)
        and mantissa (10 bits) of the half-precision value.
    """
    import struct  # function-scope import keeps this block self-contained

    F16_EXPONENT_SHIFT = 10
    F16_EXPONENT_BIAS = 15
    F16_MIN_NORMAL_EXPONENT = -14
    F16_INFINITY = 0x1F << F16_EXPONENT_SHIFT
    F16_QUIET_NAN_BIT = 0x0200
    F32_MANTISSA_BITS = 23
    F32_IMPLICIT_ONE = 1 << F32_MANTISSA_BITS
    # Number of float32 mantissa bits that do not fit in a float16 mantissa
    MANTISSA_SHIFT = F32_MANTISSA_BITS - F16_EXPONENT_SHIFT

    try:
        bits = struct.unpack('>I', struct.pack('>f', float32))[0]
    except OverflowError:
        # Finite, but too large even for float32: saturates to infinity
        # just like any other value beyond the half-precision range.
        return (0x8000 if float32 < 0 else 0) | F16_INFINITY

    sign = (bits >> 16) & 0x8000
    exponent = ((bits >> F32_MANTISSA_BITS) & 0xFF) - 127
    mantissa = bits & (F32_IMPLICIT_ONE - 1)

    if exponent == 128:
        if not mantissa:
            return sign | F16_INFINITY
        # NaN: keep the top payload bits and force the quiet bit so the
        # mantissa can never collapse to zero (which would encode infinity).
        return sign | F16_INFINITY | F16_QUIET_NAN_BIT | (mantissa >> MANTISSA_SHIFT)

    if exponent > F16_EXPONENT_BIAS:
        return sign | F16_INFINITY

    if exponent >= F16_MIN_NORMAL_EXPONENT:
        # Normal half: re-bias the exponent, drop the low 13 mantissa bits.
        shift = MANTISSA_SHIFT
        significand = mantissa
        magnitude = (exponent + F16_EXPONENT_BIAS) << F16_EXPONENT_SHIFT
    elif exponent >= F16_MIN_NORMAL_EXPONENT - F16_EXPONENT_SHIFT - 1:
        # Subnormal half: make the implicit leading one explicit and shift
        # it into the 10-bit mantissa field (exponent field stays zero).
        shift = MANTISSA_SHIFT + (F16_MIN_NORMAL_EXPONENT - exponent)
        significand = mantissa | F32_IMPLICIT_ONE
        magnitude = 0
    else:
        # Too small to round up to even the smallest subnormal (also covers
        # float32 zeros and subnormals).
        return sign

    magnitude |= significand >> shift
    remainder = significand & ((1 << shift) - 1)
    halfway = 1 << (shift - 1)
    # Round to nearest, ties to even. A carry out of the mantissa correctly
    # bumps the exponent (and turns the largest finite value into infinity).
    if remainder > halfway or (remainder == halfway and magnitude & 1):
        magnitude += 1

    return sign | magnitude