summaryrefslogtreecommitdiff
path: root/src/pyclblast
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2019-01-23 19:52:01 +0100
committer: Cedric Nugteren <web@cedricnugteren.nl> 2019-01-23 19:52:01 +0100
commit: e0541c41a17bd500ab3f03bbb9a934d9cc3b0a75 (patch)
tree: eff8b433e24ebfa23dd0a2c4207381059373b81c /src/pyclblast
parent: 347f0df32f0ddcc673e1e62f299090ac60b240a4 (diff)
Added fp32 to fp16 conversion function in Python to make haxpy example work
Diffstat (limited to 'src/pyclblast')
-rw-r--r-- src/pyclblast/samples/haxpy.py 5
-rw-r--r-- src/pyclblast/src/pyclblast.pyx 36
2 files changed, 39 insertions, 2 deletions
diff --git a/src/pyclblast/samples/haxpy.py b/src/pyclblast/samples/haxpy.py
index d6c1fef0..3db87a34 100644
--- a/src/pyclblast/samples/haxpy.py
+++ b/src/pyclblast/samples/haxpy.py
@@ -13,7 +13,8 @@ import pyclblast
# Settings for this sample
dtype = 'float16'
-alpha = np.array(1.0).astype(dtype=dtype).item()
+alpha = 1.5
+alpha_fp16 = pyclblast.float32_to_float16(alpha)
n = 4
print("# Setting up OpenCL")
@@ -31,7 +32,7 @@ clx.set(x)
cly.set(y)
print("# Example level-1 operation: AXPY")
-pyclblast.axpy(queue, n, clx, cly, alpha=alpha)
+pyclblast.axpy(queue, n, clx, cly, alpha=alpha_fp16)
queue.finish()
print("# Result for vector y: %s" % cly.get())
print("# Expected result: %s" % (alpha * x + y))
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx
index fc9f437d..14efcf8a 100644
--- a/src/pyclblast/src/pyclblast.pyx
+++ b/src/pyclblast/src/pyclblast.pyx
@@ -11,6 +11,8 @@
#
####################################################################################################
+import binascii
+import struct
import numpy as np
import pyopencl as cl
from pyopencl.array import Array
@@ -288,6 +290,40 @@ def check_matrix(a, name):
def check_vector(a, name):
check_array(a, 1, name)
+####################################################################################################
+# Half-precision utility functions
+####################################################################################################
+
def float32_to_float16(float32):
    """Convert a Python float to the bit pattern of an IEEE 754 half.

    The value is first narrowed to single precision and then converted to
    binary16 using round-to-nearest, ties-to-even.

    Args:
        float32: The number to convert. It is packed as a 32-bit float first,
            so finite values outside the single-precision range raise
            ``OverflowError`` (from ``struct.pack``).

    Returns:
        An ``int`` in ``[0, 0xFFFF]`` holding the 16 bits of the half-precision
        value (1 sign bit, 5 exponent bits, 10 mantissa bits).

    Special cases:
        * Signed zeros and infinities are preserved.
        * NaN always maps to a NaN (never to infinity); the top payload bits
          are kept and the quiet bit is set.
        * Magnitudes too large for binary16 become signed infinity.
        * Small magnitudes become binary16 subnormals, or signed zero when
          they round below the smallest subnormal (2**-24).
    """
    # Originally based on https://gamedev.stackexchange.com/a/28756, extended
    # with rounding, subnormal support and a NaN fix.
    F16_EXPONENT_SHIFT = 10
    F16_EXPONENT_BIAS = 15
    F16_MIN_NORMAL_EXPONENT = -14
    F16_MANTISSA_SHIFT = 23 - F16_EXPONENT_SHIFT
    F16_INFINITY = 0x1F << F16_EXPONENT_SHIFT
    F16_QUIET_NAN_BIT = 0x200
    F32_IMPLICIT_ONE = 0x00800000

    # Reinterpret the single-precision encoding as an unsigned 32-bit integer.
    (f32,) = struct.unpack('>I', struct.pack('>f', float32))
    sign = (f32 >> 16) & 0x8000
    exponent = ((f32 >> 23) & 0xff) - 127
    mantissa = f32 & 0x007fffff

    if exponent == 128:
        if not mantissa:
            return sign | F16_INFINITY
        # NaN: keep the most significant payload bits and force the quiet bit,
        # so a payload that lives only in the dropped bits cannot turn the
        # result into an infinity.
        payload = mantissa >> F16_MANTISSA_SHIFT
        return sign | F16_INFINITY | F16_QUIET_NAN_BIT | payload

    if exponent > 15:
        # Too large for binary16.
        return sign | F16_INFINITY

    if exponent < -25:
        # Less than half of the smallest subnormal (this also covers zero and
        # single-precision subnormals): rounds to signed zero.
        return sign

    if exponent >= F16_MIN_NORMAL_EXPONENT:
        # Normal half: the implicit leading one stays implicit.
        shift = F16_MANTISSA_SHIFT
        significand = mantissa
        base = (exponent + F16_EXPONENT_BIAS) << F16_EXPONENT_SHIFT
    else:
        # Subnormal half: make the leading one explicit and shift it down by
        # the distance to the minimum normal exponent.
        shift = F16_MANTISSA_SHIFT + (F16_MIN_NORMAL_EXPONENT - exponent)
        significand = mantissa | F32_IMPLICIT_ONE
        base = 0

    magnitude = base | (significand >> shift)
    remainder = significand & ((1 << shift) - 1)
    halfway = 1 << (shift - 1)

    # Round to nearest, ties to even. A carry out of the mantissa correctly
    # bumps the exponent (subnormal -> normal, or largest finite -> infinity).
    if remainder > halfway or (remainder == halfway and (magnitude & 1)):
        magnitude += 1

    return sign | magnitude
####################################################################################################
# Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP