summary refs log tree commit diff
diff options
context:
space:
mode:
author Cedric Nugteren <web@cedricnugteren.nl> 2019-01-23 19:52:01 +0100
committer Cedric Nugteren <web@cedricnugteren.nl> 2019-01-23 19:52:01 +0100
commit e0541c41a17bd500ab3f03bbb9a934d9cc3b0a75 (patch)
tree eff8b433e24ebfa23dd0a2c4207381059373b81c
parent 347f0df32f0ddcc673e1e62f299090ac60b240a4 (diff)
Added fp32 to fp16 conversion function in Python to make haxpy example work
-rwxr-xr-x scripts/generator/generator.py 2
-rw-r--r-- src/pyclblast/samples/haxpy.py 5
-rw-r--r-- src/pyclblast/src/pyclblast.pyx 36
3 files changed, 40 insertions, 3 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 1e274abd..76c5dc1c 100755
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -49,7 +49,7 @@ FILES = [
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
-HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 291]
+HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 327]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232
diff --git a/src/pyclblast/samples/haxpy.py b/src/pyclblast/samples/haxpy.py
index d6c1fef0..3db87a34 100644
--- a/src/pyclblast/samples/haxpy.py
+++ b/src/pyclblast/samples/haxpy.py
@@ -13,7 +13,8 @@ import pyclblast
# Settings for this sample
dtype = 'float16'
-alpha = np.array(1.0).astype(dtype=dtype).item()
+alpha = 1.5
+alpha_fp16 = pyclblast.float32_to_float16(alpha)
n = 4
print("# Setting up OpenCL")
@@ -31,7 +32,7 @@ clx.set(x)
cly.set(y)
print("# Example level-1 operation: AXPY")
-pyclblast.axpy(queue, n, clx, cly, alpha=alpha)
+pyclblast.axpy(queue, n, clx, cly, alpha=alpha_fp16)
queue.finish()
print("# Result for vector y: %s" % cly.get())
print("# Expected result: %s" % (alpha * x + y))
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx
index fc9f437d..14efcf8a 100644
--- a/src/pyclblast/src/pyclblast.pyx
+++ b/src/pyclblast/src/pyclblast.pyx
@@ -11,6 +11,8 @@
#
####################################################################################################
+import binascii
+import struct
import numpy as np
import pyopencl as cl
from pyopencl.array import Array
@@ -288,6 +290,40 @@ def check_matrix(a, name):
def check_vector(a, name):
check_array(a, 1, name)
+####################################################################################################
+# Half-precision utility functions
+####################################################################################################
+
+def float32_to_float16(float32):
+ # Taken from https://gamedev.stackexchange.com/a/28756
+ F16_EXPONENT_BITS = 0x1F
+ F16_EXPONENT_SHIFT = 10
+ F16_EXPONENT_BIAS = 15
+ F16_MANTISSA_BITS = 0x3ff
+ F16_MANTISSA_SHIFT = (23 - F16_EXPONENT_SHIFT)
+ F16_MAX_EXPONENT = (F16_EXPONENT_BITS << F16_EXPONENT_SHIFT)
+
+ a = struct.pack('>f', float32)
+ b = binascii.hexlify(a)
+
+ f32 = int(b, 16)
+ sign = (f32 >> 16) & 0x8000
+ exponent = ((f32 >> 23) & 0xff) - 127
+ mantissa = f32 & 0x007fffff
+
+ if exponent == 128:
+ f16 = sign | F16_MAX_EXPONENT
+ if mantissa:
+ f16 |= (mantissa & F16_MANTISSA_BITS)
+ elif exponent > 15:
+ f16 = sign | F16_MAX_EXPONENT
+ elif exponent > -15:
+ exponent += F16_EXPONENT_BIAS
+ mantissa >>= F16_MANTISSA_SHIFT
+ f16 = sign | exponent << F16_EXPONENT_SHIFT | mantissa
+ else:
+ f16 = sign
+ return f16
####################################################################################################
# Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP