summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: Cedric Nugteren <web@cedricnugteren.nl> 2020-05-10 12:26:25 +0200
committer: Cedric Nugteren <web@cedricnugteren.nl> 2020-05-10 12:26:25 +0200
commit: b94e81af10b8cb22ac338a2a8db349222ab1b57a (patch)
tree: 4909f1e0fe37f84381ff78c2e13bc56d25f0e03b /src
parent: 5f4b3ffcf7c8f90eee33a8504ede00ea52f79c0e (diff)
Added pyclblast bindings for the 3 batched routines
Diffstat (limited to 'src')
-rw-r--r-- src/pyclblast/src/pyclblast.pyx | 257
1 files changed, 257 insertions, 0 deletions
diff --git a/src/pyclblast/src/pyclblast.pyx b/src/pyclblast/src/pyclblast.pyx
index 14efcf8a..eb46649f 100644
--- a/src/pyclblast/src/pyclblast.pyx
+++ b/src/pyclblast/src/pyclblast.pyx
@@ -364,6 +364,7 @@ def swap(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0):
err = CLBlastHswap(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXswap' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -405,6 +406,7 @@ def scal(queue, n, x, x_inc = 1, alpha = 1.0, x_offset = 0):
err = CLBlastHscal(n, <cl_half>alpha, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXscal' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -448,6 +450,7 @@ def copy(queue, n, x, y, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0):
err = CLBlastHcopy(n, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXcopy' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -491,6 +494,7 @@ def axpy(queue, n, x, y, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset = 0, y_offs
err = CLBlastHaxpy(n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXaxpy' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -530,6 +534,7 @@ def dot(queue, n, x, y, dot, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0, d
err = CLBlastHdot(n, dot_buffer, dot_offset, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXdot' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -566,6 +571,7 @@ def dotu(queue, n, x, y, dot, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0,
err = CLBlastZdotu(n, dot_buffer, dot_offset, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXdotu' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -602,6 +608,7 @@ def dotc(queue, n, x, y, dot, x_inc = 1, y_inc = 1, x_offset = 0, y_offset = 0,
err = CLBlastZdotc(n, dot_buffer, dot_offset, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXdotc' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -645,6 +652,7 @@ def nrm2(queue, n, x, nrm2, x_inc = 1, x_offset = 0, nrm2_offset = 0):
err = CLBlastHnrm2(n, nrm2_buffer, nrm2_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXnrm2' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -688,6 +696,7 @@ def asum(queue, n, x, asum, x_inc = 1, x_offset = 0, asum_offset = 0):
err = CLBlastHasum(n, asum_buffer, asum_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXasum' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -731,6 +740,7 @@ def sum(queue, n, x, sum, x_inc = 1, x_offset = 0, sum_offset = 0):
err = CLBlastHsum(n, sum_buffer, sum_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsum' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -774,6 +784,7 @@ def amax(queue, n, x, imax, x_inc = 1, x_offset = 0, imax_offset = 0):
err = CLBlastiHamax(n, imax_buffer, imax_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXamax' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -817,6 +828,7 @@ def amin(queue, n, x, imin, x_inc = 1, x_offset = 0, imin_offset = 0):
err = CLBlastiHamin(n, imin_buffer, imin_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXamin' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -860,6 +872,7 @@ def max(queue, n, x, imax, x_inc = 1, x_offset = 0, imax_offset = 0):
err = CLBlastiHmax(n, imax_buffer, imax_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXmax' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -903,6 +916,7 @@ def min(queue, n, x, imin, x_inc = 1, x_offset = 0, imin_offset = 0):
err = CLBlastiHmin(n, imin_buffer, imin_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXmin' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -949,6 +963,7 @@ def gemv(queue, m, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0
err = CLBlastHgemv(CLBlastLayoutRowMajor, a_transpose, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXgemv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -995,6 +1010,7 @@ def gbmv(queue, m, n, kl, ku, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0,
err = CLBlastHgbmv(CLBlastLayoutRowMajor, a_transpose, m, n, kl, ku, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXgbmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1032,6 +1048,7 @@ def hemv(queue, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.0,
err = CLBlastZhemv(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXhemv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1069,6 +1086,7 @@ def hbmv(queue, n, k, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0
err = CLBlastZhbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXhbmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1106,6 +1124,7 @@ def hpmv(queue, n, ap, x, y, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.
err = CLBlastZhpmv(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXhpmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1146,6 +1165,7 @@ def symv(queue, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.0,
err = CLBlastHsymv(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsymv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1186,6 +1206,7 @@ def sbmv(queue, n, k, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0
err = CLBlastHsbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsbmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1226,6 +1247,7 @@ def spmv(queue, n, ap, x, y, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.
err = CLBlastHspmv(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXspmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1272,6 +1294,7 @@ def trmv(queue, n, a, x, a_ld, x_inc = 1, lower_triangle = False, a_transp = Fal
err = CLBlastHtrmv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXtrmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1318,6 +1341,7 @@ def tbmv(queue, n, k, a, x, a_ld, x_inc = 1, lower_triangle = False, a_transp =
err = CLBlastHtbmv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, k, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXtbmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1364,6 +1388,7 @@ def tpmv(queue, n, ap, x, ap_ld, x_inc = 1, lower_triangle = False, a_transp = F
err = CLBlastHtpmv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, ap_buffer, ap_offset, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXtpmv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1407,6 +1432,7 @@ def trsv(queue, n, a, x, a_ld, x_inc = 1, lower_triangle = False, a_transp = Fal
err = CLBlastZtrsv(CLBlastLayoutRowMajor, triangle, a_transpose, diagonal, n, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXtrsv' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1446,6 +1472,7 @@ def ger(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset
err = CLBlastHger(CLBlastLayoutRowMajor, m, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXger' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1482,6 +1509,7 @@ def geru(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset
err = CLBlastZgeru(CLBlastLayoutRowMajor, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXgeru' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1518,6 +1546,7 @@ def gerc(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset
err = CLBlastZgerc(CLBlastLayoutRowMajor, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXgerc' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1553,6 +1582,7 @@ def her(queue, n, x, a, a_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, x_
err = CLBlastZher(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXher' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1588,6 +1618,7 @@ def hpr(queue, n, x, ap, ap_ld, x_inc = 1, alpha = 1.0, lower_triangle = False,
err = CLBlastZhpr(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXhpr' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1625,6 +1656,7 @@ def her2(queue, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_trian
err = CLBlastZher2(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXher2' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1662,6 +1694,7 @@ def hpr2(queue, n, x, y, ap, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_tri
err = CLBlastZhpr2(CLBlastLayoutRowMajor, triangle, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXhpr2' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1700,6 +1733,7 @@ def syr(queue, n, x, a, a_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, x_
err = CLBlastHsyr(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsyr' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1738,6 +1772,7 @@ def spr(queue, n, x, ap, ap_ld, x_inc = 1, alpha = 1.0, lower_triangle = False,
err = CLBlastHspr(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXspr' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1778,6 +1813,7 @@ def syr2(queue, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_trian
err = CLBlastHsyr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsyr2' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1818,6 +1854,7 @@ def spr2(queue, n, x, y, ap, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_tri
err = CLBlastHspr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXspr2' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1865,6 +1902,7 @@ def gemm(queue, m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_t
err = CLBlastHgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXgemm' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1912,6 +1950,7 @@ def symm(queue, m, n, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, right_
err = CLBlastHsymm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsymm' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1950,6 +1989,7 @@ def hemm(queue, m, n, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, right_
err = CLBlastZhemm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXhemm' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -1995,6 +2035,7 @@ def syrk(queue, n, k, a, c, a_ld, c_ld, alpha = 1.0, beta = 0.0, lower_triangle
err = CLBlastHsyrk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsyrk' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -2031,6 +2072,7 @@ def herk(queue, n, k, a, c, a_ld, c_ld, alpha = 1.0, beta = 0.0, lower_triangle
err = CLBlastZherk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_double>alpha, a_buffer, a_offset, a_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXherk' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -2078,6 +2120,7 @@ def syr2k(queue, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, lower
err = CLBlastHsyr2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXsyr2k' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -2116,6 +2159,7 @@ def her2k(queue, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, lower
err = CLBlastZher2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXher2k' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -2163,6 +2207,7 @@ def trmm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t
err = CLBlastHtrmm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXtrmm' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
@@ -2207,11 +2252,223 @@ def trsm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t
err = CLBlastZtrsm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
if err != CLBlastSuccess:
raise RuntimeError("PyCLBlast: 'CLBlastXtrsm' failed: %s" % get_status_message(err))
return cl.Event.from_int_ptr(<size_t>event)
####################################################################################################
+# Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
+####################################################################################################
+
+# C prototypes of the batched AXPY entry points from clblast_c.h, one per
+# precision: S (float), D (double), C (cl_float2), Z (cl_double2), H (cl_half).
+cdef extern from "clblast_c.h":
+    CLBlastStatusCode CLBlastSaxpyBatched(
+        const size_t n, const float *alphas,
+        const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+        cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastDaxpyBatched(
+        const size_t n, const double *alphas,
+        const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+        cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastCaxpyBatched(
+        const size_t n, const cl_float2 *alphas,
+        const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+        cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastZaxpyBatched(
+        const size_t n, const cl_double2 *alphas,
+        const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+        cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastHaxpyBatched(
+        const size_t n, const cl_half *alphas,
+        const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
+        cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+
+def axpyBatched(queue, n, x, y, alphas, x_offsets, y_offsets, x_inc = 1, y_inc = 1):
+    """
+    xAXPYBATCHED: Batched version of AXPY
+
+    Copies the per-batch scalars and offsets into temporary host arrays and
+    forwards them to the CLBlastXaxpyBatched variant matching the data-type
+    of x and y.
+
+    Arguments:
+        queue: object exposing 'int_ptr', used as the OpenCL command queue.
+        n: forwarded as the 'n' argument of CLBlastXaxpyBatched.
+        x, y: device arrays exposing 'base_data.int_ptr'; validated with
+            check_dtype and check_vector.
+        alphas: one scalar per batch.
+        x_offsets, y_offsets: one offset per batch into x and y.
+        x_inc, y_inc: forwarded as the increments of x and y.
+
+    Returns:
+        The event produced by the CLBlast call, wrapped by cl.Event.from_int_ptr.
+
+    Raises:
+        RuntimeError: if alphas, x_offsets and y_offsets differ in length, or
+            if CLBlast reports a status other than CLBlastSuccess.
+        ValueError: if the data-type is not one of the supported ones.
+        MemoryError: if the temporary host arrays cannot be allocated.
+    """
+
+    # All C-level locals are declared up front because 'cdef' statements are
+    # not allowed inside the try block below. The pointers start as NULL so
+    # the cleanup code can free them unconditionally.
+    cdef size_t *x_offsets_c = NULL
+    cdef size_t *y_offsets_c = NULL
+    cdef void *alphas_c = NULL
+    cdef cl_mem x_buffer
+    cdef cl_mem y_buffer
+    cdef cl_command_queue command_queue
+    cdef cl_event event = NULL
+    cdef CLBlastStatusCode err
+
+    dtype = check_dtype([x, y], ["float32", "float64", "complex64", "complex128", "float16"])
+    check_vector(x, "x")
+    check_vector(y, "y")
+
+    # A chained 'a != b != c' only means 'a != b and b != c', which lets
+    # mismatches such as lengths (2, 2, 3) through; test full equality instead.
+    if not (len(x_offsets) == len(y_offsets) == len(alphas)):
+        raise RuntimeError("PyCLBlast: 'CLBlastXaxpyBatched' failed: length of batch-sized arguments x_offsets, y_offsets, alphas should be equal")
+    batch_count = len(x_offsets)
+
+    try:
+        x_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
+        y_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
+        # Sized for cl_double2, the widest scalar type written below, so the
+        # buffer is large enough whichever data-type branch is taken.
+        # NOTE(review): the previous size was 'sizeof(dtype_size[dtype])';
+        # dtype_size is defined elsewhere, confirm it was the element size.
+        alphas_c = <void *> PyMem_Malloc(batch_count * sizeof(cl_double2))
+        if x_offsets_c == NULL or y_offsets_c == NULL or alphas_c == NULL:
+            raise MemoryError("PyCLBlast: 'CLBlastXaxpyBatched' failed: could not allocate host memory for the batch-sized arguments")
+
+        for i in range(batch_count):
+            x_offsets_c[i] = x_offsets[i]
+            y_offsets_c[i] = y_offsets[i]
+
+        x_buffer = <cl_mem><size_t>x.base_data.int_ptr
+        y_buffer = <cl_mem><size_t>y.base_data.int_ptr
+        command_queue = <cl_command_queue><size_t>queue.int_ptr
+
+        # Per data-type: convert the scalars, then call the matching routine.
+        if dtype == np.dtype("float32"):
+            for i in range(batch_count):
+                (<cl_float*>alphas_c)[i] = <cl_float>alphas[i]
+            err = CLBlastSaxpyBatched(n, <cl_float*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
+        elif dtype == np.dtype("float64"):
+            for i in range(batch_count):
+                (<cl_double*>alphas_c)[i] = <cl_double>alphas[i]
+            err = CLBlastDaxpyBatched(n, <cl_double*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
+        elif dtype == np.dtype("complex64"):
+            for i in range(batch_count):
+                (<cl_float2*>alphas_c)[i] = <cl_float2>cl_float2(x=alphas[i].real,y=alphas[i].imag)
+            err = CLBlastCaxpyBatched(n, <cl_float2*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
+        elif dtype == np.dtype("complex128"):
+            for i in range(batch_count):
+                (<cl_double2*>alphas_c)[i] = <cl_double2>cl_double2(x=alphas[i].real,y=alphas[i].imag)
+            err = CLBlastZaxpyBatched(n, <cl_double2*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
+        elif dtype == np.dtype("float16"):
+            for i in range(batch_count):
+                (<cl_half*>alphas_c)[i] = <cl_half>alphas[i]
+            err = CLBlastHaxpyBatched(n, <cl_half*>alphas_c, x_buffer, x_offsets_c, x_inc, y_buffer, y_offsets_c, y_inc, batch_count, &command_queue, &event)
+        else:
+            raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+    finally:
+        # Runs on success and on every error path, so the temporary arrays
+        # are never leaked; PyMem_Free is a no-op on NULL.
+        PyMem_Free(x_offsets_c)
+        PyMem_Free(y_offsets_c)
+        PyMem_Free(alphas_c)
+
+    if err != CLBlastSuccess:
+        raise RuntimeError("PyCLBlast: 'CLBlastXaxpyBatched' failed: %s" % get_status_message(err))
+    return cl.Event.from_int_ptr(<size_t>event)
+
+####################################################################################################
+# Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
+####################################################################################################
+
+# C prototypes of the batched GEMM entry points from clblast_c.h, one per
+# precision: S (float), D (double), C (cl_float2), Z (cl_double2), H (cl_half).
+cdef extern from "clblast_c.h":
+    CLBlastStatusCode CLBlastSgemmBatched(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const float *alphas,
+        const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+        const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+        const float *betas,
+        cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastDgemmBatched(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const double *alphas,
+        const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+        const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+        const double *betas,
+        cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastCgemmBatched(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const cl_float2 *alphas,
+        const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+        const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+        const cl_float2 *betas,
+        cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastZgemmBatched(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const cl_double2 *alphas,
+        const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+        const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+        const cl_double2 *betas,
+        cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+    CLBlastStatusCode CLBlastHgemmBatched(
+        const CLBlastLayout layout,
+        const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
+        const size_t m, const size_t n, const size_t k,
+        const cl_half *alphas,
+        const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
+        const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
+        const cl_half *betas,
+        cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
+        const size_t batch_count, cl_command_queue* queue, cl_event* event)
+
+def gemmBatched(queue, m, n, k, a, b, c, alphas, betas, a_ld, b_ld, c_ld, a_offsets, b_offsets, c_offsets, a_transp = False, b_transp = False):
+    """
+    xGEMMBATCHED: Batched version of GEMM
+
+    The batch size is taken from len(a_offsets); alphas, betas, a_offsets,
+    b_offsets and c_offsets must all hold exactly one entry per batch. The
+    matrices are always passed to CLBlast with the row-major layout.
+
+    Returns the OpenCL event of the enqueued operation, wrapped with
+    cl.Event.from_int_ptr.
+
+    Raises:
+        RuntimeError: the batch-sized arguments differ in length, or CLBlast
+            returned a status other than CLBlastSuccess.
+        ValueError: the data-type of a, b and c is not one of the five
+            handled below.
+        MemoryError: the temporary host arrays could not be allocated.
+    """
+
+    dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128", "float16"])
+    check_matrix(a, "a")
+    check_matrix(b, "b")
+    check_matrix(c, "c")
+
+    # Every batch-sized argument has to agree. A chained '!=' only compares
+    # neighbouring pairs and stops at the first equal pair, so collect all the
+    # lengths in a set and require a single distinct value instead.
+    if len({len(a_offsets), len(b_offsets), len(c_offsets), len(alphas), len(betas)}) != 1:
+        raise RuntimeError("PyCLBlast: 'CLBlastXgemmBatched' failed: length of batch-sized arguments a_offsets, b_offsets, c_offsets, alphas, betas should be equal")
+    batch_count = len(a_offsets)
+
+    # Size of one alpha/beta entry, taken from the C element type so that the
+    # widest case (complex128 -> cl_double2) gets enough room. Resolving it
+    # before any allocation also rejects unknown types without leaking.
+    cdef size_t scalar_size
+    if dtype == np.dtype("float32"):
+        scalar_size = sizeof(cl_float)
+    elif dtype == np.dtype("float64"):
+        scalar_size = sizeof(cl_double)
+    elif dtype == np.dtype("complex64"):
+        scalar_size = sizeof(cl_float2)
+    elif dtype == np.dtype("complex128"):
+        scalar_size = sizeof(cl_double2)
+    elif dtype == np.dtype("float16"):
+        scalar_size = sizeof(cl_half)
+    else:
+        raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
+    cdef cl_mem a_buffer = <cl_mem><size_t>a.base_data.int_ptr
+    cdef cl_mem b_buffer = <cl_mem><size_t>b.base_data.int_ptr
+    cdef cl_mem c_buffer = <cl_mem><size_t>c.base_data.int_ptr
+
+    cdef cl_command_queue command_queue = <cl_command_queue><size_t>queue.int_ptr
+    cdef cl_event event = NULL
+    a_transpose = CLBlastTransposeYes if a_transp else CLBlastTransposeNo
+    b_transpose = CLBlastTransposeYes if b_transp else CLBlastTransposeNo
+
+    # Host-side staging arrays. They start out as NULL so that the 'finally'
+    # clause can free them unconditionally (PyMem_Free(NULL) is a no-op).
+    cdef size_t *a_offsets_c = NULL
+    cdef size_t *b_offsets_c = NULL
+    cdef size_t *c_offsets_c = NULL
+    cdef void *alphas_c = NULL
+    cdef void *betas_c = NULL
+
+    cdef CLBlastStatusCode err
+    try:
+        a_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
+        b_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
+        c_offsets_c = <size_t *> PyMem_Malloc(batch_count * sizeof(size_t))
+        alphas_c = PyMem_Malloc(batch_count * scalar_size)
+        betas_c = PyMem_Malloc(batch_count * scalar_size)
+        if a_offsets_c == NULL or b_offsets_c == NULL or c_offsets_c == NULL or alphas_c == NULL or betas_c == NULL:
+            raise MemoryError("PyCLBlast: 'CLBlastXgemmBatched' failed: could not allocate the host-side batch arrays")
+
+        for i in range(batch_count):
+            a_offsets_c[i] = a_offsets[i]
+            b_offsets_c[i] = b_offsets[i]
+            c_offsets_c[i] = c_offsets[i]
+
+        # Convert the scalars and dispatch on the data-type once, rather than
+        # re-testing the data-type for every batch entry.
+        if dtype == np.dtype("float32"):
+            for i in range(batch_count):
+                (<cl_float*>alphas_c)[i] = <cl_float>alphas[i]
+                (<cl_float*>betas_c)[i] = <cl_float>betas[i]
+            err = CLBlastSgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+                                      <cl_float*>alphas_c, a_buffer, a_offsets_c, a_ld,
+                                      b_buffer, b_offsets_c, b_ld,
+                                      <cl_float*>betas_c, c_buffer, c_offsets_c, c_ld,
+                                      batch_count, &command_queue, &event)
+        elif dtype == np.dtype("float64"):
+            for i in range(batch_count):
+                (<cl_double*>alphas_c)[i] = <cl_double>alphas[i]
+                (<cl_double*>betas_c)[i] = <cl_double>betas[i]
+            err = CLBlastDgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+                                      <cl_double*>alphas_c, a_buffer, a_offsets_c, a_ld,
+                                      b_buffer, b_offsets_c, b_ld,
+                                      <cl_double*>betas_c, c_buffer, c_offsets_c, c_ld,
+                                      batch_count, &command_queue, &event)
+        elif dtype == np.dtype("complex64"):
+            for i in range(batch_count):
+                (<cl_float2*>alphas_c)[i] = <cl_float2>cl_float2(x=alphas[i].real,y=alphas[i].imag)
+                (<cl_float2*>betas_c)[i] = <cl_float2>cl_float2(x=betas[i].real,y=betas[i].imag)
+            err = CLBlastCgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+                                      <cl_float2*>alphas_c, a_buffer, a_offsets_c, a_ld,
+                                      b_buffer, b_offsets_c, b_ld,
+                                      <cl_float2*>betas_c, c_buffer, c_offsets_c, c_ld,
+                                      batch_count, &command_queue, &event)
+        elif dtype == np.dtype("complex128"):
+            for i in range(batch_count):
+                (<cl_double2*>alphas_c)[i] = <cl_double2>cl_double2(x=alphas[i].real,y=alphas[i].imag)
+                (<cl_double2*>betas_c)[i] = <cl_double2>cl_double2(x=betas[i].real,y=betas[i].imag)
+            err = CLBlastZgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+                                      <cl_double2*>alphas_c, a_buffer, a_offsets_c, a_ld,
+                                      b_buffer, b_offsets_c, b_ld,
+                                      <cl_double2*>betas_c, c_buffer, c_offsets_c, c_ld,
+                                      batch_count, &command_queue, &event)
+        else:
+            # float16: the only data-type left after the size lookup above.
+            for i in range(batch_count):
+                (<cl_half*>alphas_c)[i] = <cl_half>alphas[i]
+                (<cl_half*>betas_c)[i] = <cl_half>betas[i]
+            err = CLBlastHgemmBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k,
+                                      <cl_half*>alphas_c, a_buffer, a_offsets_c, a_ld,
+                                      b_buffer, b_offsets_c, b_ld,
+                                      <cl_half*>betas_c, c_buffer, c_offsets_c, c_ld,
+                                      batch_count, &command_queue, &event)
+    finally:
+        # Runs on success and on any exception raised above, so a failed
+        # conversion of an offset or scalar no longer leaks the arrays.
+        PyMem_Free(a_offsets_c)
+        PyMem_Free(b_offsets_c)
+        PyMem_Free(c_offsets_c)
+        PyMem_Free(alphas_c)
+        PyMem_Free(betas_c)
+
+    if err != CLBlastSuccess:
+        raise RuntimeError("PyCLBlast: 'CLBlastXgemmBatched' failed: %s" % get_status_message(err))
+    return cl.Event.from_int_ptr(<size_t>event)
+
+####################################################################################################
+# StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
+####################################################################################################
+
+# C prototypes of the strided-batched GEMM routines declared in clblast_c.h.
+# The five variants share one parameter list and differ only in the scalar
+# type of alpha and beta: float (S), double (D), cl_float2 (C), cl_double2 (Z)
+# and cl_half (H). Unlike the xGEMMBATCHED prototypes, alpha, beta and the
+# three offsets are passed by value, together with one stride per matrix.
+cdef extern from "clblast_c.h":
+ CLBlastStatusCode CLBlastSgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const float alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const float beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastDgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const double alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const double beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastCgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_float2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const cl_float2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastZgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_double2 alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const cl_double2 beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
+ CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose, const size_t m, const size_t n, const size_t k, const cl_half alpha, const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride, const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const cl_half beta, cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride, const size_t batch_count,cl_command_queue* queue, cl_event* event)
+
+def gemmStridedBatched(queue, m, n, k, batch_count, a, b, c, a_ld, b_ld, c_ld, a_stride, b_stride, c_stride, alpha = 1.0, beta = 0.0, a_transp = False, b_transp = False, a_offset = 0, b_offset = 0, c_offset = 0):
+    """
+    xGEMMSTRIDEDBATCHED: StridedBatched version of GEMM
+
+    Forwards a single alpha/beta pair, one offset and one stride per matrix,
+    and batch_count to the CLBlast routine matching the common data-type of
+    a, b and c. The row-major layout is always used.
+
+    Returns the OpenCL event of the enqueued operation, wrapped with
+    cl.Event.from_int_ptr.
+
+    Raises:
+        ValueError: the data-type is not one of the five handled below.
+        RuntimeError: CLBlast returned a status other than CLBlastSuccess.
+    """
+
+    dtype = check_dtype([a, b, c], ["float32", "float64", "complex64", "complex128", "float16"])
+    for matrix, label in ((a, "a"), (b, "b"), (c, "c")):
+        check_matrix(matrix, label)
+
+    # Raw OpenCL handles behind the Python-level array and queue objects.
+    cdef cl_mem a_mem = <cl_mem><size_t>a.base_data.int_ptr
+    cdef cl_mem b_mem = <cl_mem><size_t>b.base_data.int_ptr
+    cdef cl_mem c_mem = <cl_mem><size_t>c.base_data.int_ptr
+    cdef cl_command_queue cl_queue = <cl_command_queue><size_t>queue.int_ptr
+    cdef cl_event done = NULL
+    cdef CLBlastStatusCode status
+
+    a_op = CLBlastTransposeYes if a_transp else CLBlastTransposeNo
+    b_op = CLBlastTransposeYes if b_transp else CLBlastTransposeNo
+
+    # One branch per supported data-type; only the scalar conversions and the
+    # routine prefix (S/D/C/Z/H) differ between them.
+    if dtype == np.dtype("float32"):
+        status = CLBlastSgemmStridedBatched(
+            CLBlastLayoutRowMajor, a_op, b_op, m, n, k,
+            <cl_float>alpha,
+            a_mem, a_offset, a_ld, a_stride,
+            b_mem, b_offset, b_ld, b_stride,
+            <cl_float>beta,
+            c_mem, c_offset, c_ld, c_stride,
+            batch_count, &cl_queue, &done)
+    elif dtype == np.dtype("float64"):
+        status = CLBlastDgemmStridedBatched(
+            CLBlastLayoutRowMajor, a_op, b_op, m, n, k,
+            <cl_double>alpha,
+            a_mem, a_offset, a_ld, a_stride,
+            b_mem, b_offset, b_ld, b_stride,
+            <cl_double>beta,
+            c_mem, c_offset, c_ld, c_stride,
+            batch_count, &cl_queue, &done)
+    elif dtype == np.dtype("complex64"):
+        status = CLBlastCgemmStridedBatched(
+            CLBlastLayoutRowMajor, a_op, b_op, m, n, k,
+            <cl_float2>cl_float2(x=alpha.real,y=alpha.imag),
+            a_mem, a_offset, a_ld, a_stride,
+            b_mem, b_offset, b_ld, b_stride,
+            <cl_float2>cl_float2(x=beta.real,y=beta.imag),
+            c_mem, c_offset, c_ld, c_stride,
+            batch_count, &cl_queue, &done)
+    elif dtype == np.dtype("complex128"):
+        status = CLBlastZgemmStridedBatched(
+            CLBlastLayoutRowMajor, a_op, b_op, m, n, k,
+            <cl_double2>cl_double2(x=alpha.real,y=alpha.imag),
+            a_mem, a_offset, a_ld, a_stride,
+            b_mem, b_offset, b_ld, b_stride,
+            <cl_double2>cl_double2(x=beta.real,y=beta.imag),
+            c_mem, c_offset, c_ld, c_stride,
+            batch_count, &cl_queue, &done)
+    elif dtype == np.dtype("float16"):
+        status = CLBlastHgemmStridedBatched(
+            CLBlastLayoutRowMajor, a_op, b_op, m, n, k,
+            <cl_half>alpha,
+            a_mem, a_offset, a_ld, a_stride,
+            b_mem, b_offset, b_ld, b_stride,
+            <cl_half>beta,
+            c_mem, c_offset, c_ld, c_stride,
+            batch_count, &cl_queue, &done)
+    else:
+        raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)
+
+    if status == CLBlastSuccess:
+        return cl.Event.from_int_ptr(<size_t>done)
+    raise RuntimeError("PyCLBlast: 'CLBlastXgemmStridedBatched' failed: %s" % get_status_message(status))
+
+####################################################################################################
# Overrides the parameters
####################################################################################################