summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2018-03-11 15:32:36 +0100
committerCedric Nugteren <web@cedricnugteren.nl>2018-03-11 15:32:36 +0100
commitbcf12084319ed6eb687e2308fcb050eaad7c95ec (patch)
tree6a168214173f07acee33ddb0f16f090745a69ea8
parent0dd1bc6f4880308d3e240545203413d3d902d5b7 (diff)
Added basic tests for PyCLBlast
-rw-r--r--src/pyclblast/README.md8
-rw-r--r--src/pyclblast/test/__init__.py0
-rw-r--r--src/pyclblast/test/test_pyclblast.py77
3 files changed, 85 insertions, 0 deletions
diff --git a/src/pyclblast/README.md b/src/pyclblast/README.md
index be37af01..2f6ebed7 100644
--- a/src/pyclblast/README.md
+++ b/src/pyclblast/README.md
@@ -29,3 +29,11 @@ After installation OpenCL and CLBlast, simply use pip to install PyCLBlast, e.g.
pip install --user pyclblast
To start using the library, browse the [CLBlast](https://github.com/CNugteren/CLBlast) documentation or check out the PyCLBlast samples provides in the `samples` subfolder.
+
+
+Testing PyCLBlast
+-------------
+
+The main exhaustive tests are the main CLBlast test binaries. Apart from that, you can also run the PyCLBlast smoke tests from the `test` subfolder, e.g. as follows:
+
+ python -m unittest discover
diff --git a/src/pyclblast/test/__init__.py b/src/pyclblast/test/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/pyclblast/test/__init__.py
diff --git a/src/pyclblast/test/test_pyclblast.py b/src/pyclblast/test/test_pyclblast.py
new file mode 100644
index 00000000..9f2e2b47
--- /dev/null
+++ b/src/pyclblast/test/test_pyclblast.py
@@ -0,0 +1,77 @@
+
+####################################################################################################
+# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0.
+#
+# Author(s):
+# Cedric Nugteren <www.cedricnugteren.nl>
+#
+# This file test PyCLBlast: the Python interface to CLBlast. It is not exhaustive. For full testing
+# it is recommended to run the regular CLBlast tests, this is just a small smoke test.
+#
+####################################################################################################
+
+import unittest
+
+import numpy as np
+import pyopencl as cl
+from pyopencl.array import Array
+
+import pyclblast
+
+
+class TestPyCLBlast(unittest.TestCase):
+
+ @staticmethod
+ def setup(sizes, dtype):
+ ctx = cl.create_some_context()
+ queue = cl.CommandQueue(ctx)
+ host_arrays, device_arrays = [], []
+ for size in sizes:
+ numpy_array = np.random.rand(*size).astype(dtype=dtype)
+ opencl_array = Array(queue, numpy_array.shape, numpy_array.dtype)
+ opencl_array.set(numpy_array)
+ host_arrays.append(numpy_array)
+ device_arrays.append(opencl_array)
+ return queue, host_arrays, device_arrays
+
+ def test_axpy(self):
+ for dtype in ["float32", "complex64"]:
+ for alpha in [1.0, 3.1]:
+ for n in [1, 7, 32]:
+ queue, h, d = self.setup([(n,), (n,)], dtype=dtype)
+ pyclblast.axpy(queue, n, d[0], d[1], alpha=alpha)
+ result = d[1].get()
+ reference = alpha * h[0] + h[1]
+ for i in range(n):
+ self.assertAlmostEqual(reference[i], result[i], places=3)
+
+ def test_gemv(self):
+ for dtype in ["float32", "complex64"]:
+ for beta in [1.0]:
+ for alpha in [1.0, 3.1]:
+ for m in [1, 7, 32]:
+ for n in [1, 7, 32]:
+ queue, h, d = self.setup([(m, n), (n,), (m,)], dtype=dtype)
+ pyclblast.gemv(queue, m, n, d[0], d[1], d[2],
+ a_ld=n, alpha=alpha, beta=beta)
+ result = d[2].get()
+ reference = alpha * np.dot(h[0], h[1]) + beta * h[2]
+ for i in range(m):
+ self.assertAlmostEqual(reference[i], result[i], places=3)
+
+ def test_gemm(self):
+ for dtype in ["float32", "complex64"]:
+ for beta in [1.0]:
+ for alpha in [1.0, 3.1]:
+ for m in [1, 7, 32]:
+ for n in [1, 7, 32]:
+ for k in [1, 7, 32]:
+ queue, h, d = self.setup([(m, k), (k, n), (m, n)], dtype=dtype)
+ pyclblast.gemm(queue, m, n, k, d[0], d[1], d[2],
+ a_ld=k, b_ld=n, c_ld=n, alpha=alpha, beta=beta)
+ result = d[2].get()
+ reference = alpha * np.dot(h[0], h[1]) + beta * h[2]
+ for i in range(m):
+ for j in range(n):
+ self.assertAlmostEqual(reference[i, j], result[i, j],
+ places=3)