From 58e70c56f15497d6ff4a042048a241f93cebe38d Mon Sep 17 00:00:00 2001 From: kodonell Date: Mon, 26 Mar 2018 08:51:55 +1300 Subject: tidying up pyclblast override_parameters api, and added example --- src/pyclblast/samples/sgemm.py | 19 ++++++++++++++++--- src/pyclblast/src/pyclblast.pyx | 26 +++++++++++++------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/pyclblast/samples/sgemm.py b/src/pyclblast/samples/sgemm.py index c872553f..3f149159 100644 --- a/src/pyclblast/samples/sgemm.py +++ b/src/pyclblast/samples/sgemm.py @@ -10,6 +10,7 @@ import numpy as np import pyopencl as cl from pyopencl.array import Array import pyclblast +from datetime import datetime # Settings for this sample dtype = 'float32' @@ -19,7 +20,7 @@ ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) print("# Setting up Numpy arrays") -m, n, k = 2, 3, 4 +m, n, k = 128, 256, 512 a = np.random.rand(m, k).astype(dtype=dtype) b = np.random.rand(k, n).astype(dtype=dtype) c = np.random.rand(m, n).astype(dtype=dtype) @@ -34,5 +35,17 @@ clc.set(c) print("# Example level-3 operation: GEMM") pyclblast.gemm(queue, m, n, k, cla, clb, clc, a_ld=k, b_ld=n, c_ld=n) -print("# Matrix C result: %s" % clc.get()) -print("# Expected result: %s" % (np.dot(a, b))) +print("# PyCLBlast matrix result is correct?:", np.allclose(clc.get(), np.dot(a, b))) + +print("# GFLOPS when tuned with different values of MWG:") +params = { "KWG": 32, "KWI": 2, "MDIMA": 8, "MDIMC": 8, "MWG": 64, "NDIMB": 8, "NDIMC": 8, "NWG": 64, "SA": 0, "SB": 0, "STRM": 0, "STRN": 0, "VWM": 4, "VWN": 1 } +mwg = 1 +while mwg <= 256: + params["MWG"] = mwg + pyclblast.override_parameters(ctx.devices[0], 'Xgemm', 32, params) + for i in range(100): + if i == 10: + t0 = datetime.now() + pyclblast.gemm(queue, m, n, k, cla, clb, clc, a_ld=k, b_ld=n, c_ld=n) + print("#\tMWG = %-3d : %4d" % (mwg, int(2 * m * n * k / ((datetime.now() - t0).total_seconds() / 100) / 1024 ** 3))) + mwg *= 4 diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx index 9115240f..22cb680a 100644 --- a/src/pyclblast/src/pyclblast.pyx +++ b/src/pyclblast/src/pyclblast.pyx @@ -15,6 +15,8 @@ import numpy as np import pyopencl as cl from pyopencl.array import Array from libcpp cimport bool +from cpython.mem cimport PyMem_Malloc, PyMem_Free +from libc.string cimport strdup #################################################################################################### # CLBlast and OpenCL data-types @@ -2086,38 +2088,36 @@ def trsm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t # Overrides the parameters #################################################################################################### -from libc.stdlib cimport malloc, free -from libc.string cimport strdup - cdef extern from "clblast_c.h": ctypedef struct _cl_device_id: pass ctypedef _cl_device_id* cl_device_id - CLBlastStatusCode CLBlastOverrideParameters(const cl_device_id device, const char* kernel_name, const CLBlastPrecision precision, const size_t num_parameters, const char** parameters_names, const size_t* parameters_values) def override_parameters(device, kernel_name, precision, parameters): """ - precision = 16, 32, 64, 3232, 6464 - kernel name = unicode string - parameters = + Override the current parameters for the given kernel, on this device, with this precision. """ cdef cl_device_id device_id = device.int_ptr - cdef size_t n = len(parameters) - cdef const char **parameter_names = malloc(n * sizeof(char*)) - cdef size_t *parameter_values = malloc(n * sizeof(size_t)) - - # TODO: check mallocs + # read the parameters dictionary into names/values arrays, for use in CLBlastOverrideParameters + cdef size_t n = len(parameters) + cdef const char **parameter_names = PyMem_Malloc(n * sizeof(char*)) + cdef size_t *parameter_values = PyMem_Malloc(n * sizeof(size_t)) + if not (parameter_names or parameter_values): + raise MemoryError() for i, (k, v) in enumerate(parameters.items()): parameter_names[i] = strdup(k.encode('ascii')) parameter_values[i] = v + # call the underlying API err = CLBlastOverrideParameters(device_id, kernel_name.encode('ascii'), precision, n, parameter_names, parameter_values) if err != CLBlastSuccess: raise RuntimeError("PyCLBlast: 'OverrideParameters' failed: %s" % get_status_message(err)) - # TODO: free etc. + # tidy up: + PyMem_Free(parameter_names) + PyMem_Free(parameter_values) #################################################################################################### -- cgit v1.2.3