summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorCedric Nugteren <web@cedricnugteren.nl>2016-06-16 18:07:46 +0200
committerCedric Nugteren <web@cedricnugteren.nl>2016-06-16 18:07:46 +0200
commit52ccaf5b25e14c9ce032315e5e96b1f27886d481 (patch)
tree087288b7aebf2a06ffc4e7dcbcd4353f7a3be6a7 /scripts
parent39b7dbc5e37829abfbcfb77852b9138b31540b42 (diff)
Added XOMATCOPY routines to perform out-of-place matrix scaling, copying, and/or transposing
Diffstat (limited to 'scripts')
-rw-r--r--scripts/generator/generator.py31
1 files changed, 22 insertions, 9 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py
index 7bb66749..6726adda 100644
--- a/scripts/generator/generator.py
+++ b/scripts/generator/generator.py
@@ -68,6 +68,7 @@ ald_transa_m_k = "When `transpose_a == Transpose::kNo`, then `a_ld` must be at l
ald_trans_n_k = "When `transpose == Transpose::kNo`, then `a_ld` must be at least `n`, otherwise `a_ld` must be at least `k`."
ald_side_m_n = "When `side = Side::kLeft` then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `n`."
bld_m = "The value of `b_ld` must be at least `m`."
+bld_n = "The value of `b_ld` must be at least `n`."
bld_transb_k_n = "When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`."
bld_trans_n_k = "When `transpose == Transpose::kNo`, then `b_ld` must be at least `n`, otherwise `b_ld` must be at least `k`."
cld_m = "The value of `c_ld` must be at least `m`."
@@ -134,6 +135,9 @@ routines = [
Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []),
+],
+[ # Level X: extra routines (not part of BLAS)
+ Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
]]
# ==================================================================================================
@@ -148,6 +152,7 @@ def PrecisionToFullName(x):
}[x]
# ==================================================================================================
+
# Separators for the BLAS levels
separators = ["""
// =================================================================================================
@@ -160,8 +165,15 @@ separators = ["""
"""
// =================================================================================================
// BLAS level-3 (matrix-matrix) routines
+// =================================================================================================""",
+"""
+// =================================================================================================
+// Extra non-BLAS routines (level-X)
// ================================================================================================="""]
+# Names of the level sub-folders
+levelnames = ["1", "2", "3", "x"]
+
# Main header/footer for source files
header = """
// =================================================================================================
@@ -373,7 +385,7 @@ files = [
path_clblast+"/test/wrapper_clblas.h",
path_clblast+"/test/wrapper_cblas.h",
]
-header_lines = [84, 71, 93, 22, 29, 41]
+header_lines = [84, 74, 93, 22, 29, 41]
footer_lines = [17, 71, 19, 14, 6, 6]
# Checks whether the command-line arguments are valid; exists otherwise
@@ -396,7 +408,8 @@ for i in xrange(0,len(files)):
# Re-writes the body of the file
with open(files[i], "w") as f:
body = ""
- for level in [1,2,3]:
+ levels = [1,2,3] if (i == 4 or i == 5) else [1,2,3,4]
+ for level in levels:
body += separators[level-1]+"\n"
if i == 0:
body += clblast_h(routines[level-1])
@@ -417,14 +430,14 @@ for i in xrange(0,len(files)):
# ==================================================================================================
# Outputs all the correctness-test implementations
-for level in [1,2,3]:
+for level in [1,2,3,4]:
for routine in routines[level-1]:
if routine.has_tests:
- filename = path_clblast+"/test/correctness/routines/level"+str(level)+"/x"+routine.name+".cc"
+ filename = path_clblast+"/test/correctness/routines/level"+levelnames[level-1]+"/x"+routine.name+".cc"
with open(filename, "w") as f:
body = ""
body += "#include \"correctness/testblas.h\"\n"
- body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n"
+ body += "#include \"routines/level"+levelnames[level-1]+"/x"+routine.name+".h\"\n\n"
body += "// Shortcuts to the clblast namespace\n"
body += "using float2 = clblast::float2;\n"
body += "using double2 = clblast::double2;\n\n"
@@ -443,14 +456,14 @@ for level in [1,2,3]:
f.write(footer)
# Outputs all the performance-test implementations
-for level in [1,2,3]:
+for level in [1,2,3,4]:
for routine in routines[level-1]:
if routine.has_tests:
- filename = path_clblast+"/test/performance/routines/level"+str(level)+"/x"+routine.name+".cc"
+ filename = path_clblast+"/test/performance/routines/level"+levelnames[level-1]+"/x"+routine.name+".cc"
with open(filename, "w") as f:
body = ""
body += "#include \"performance/client.h\"\n"
- body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n"
+ body += "#include \"routines/level"+levelnames[level-1]+"/x"+routine.name+".h\"\n\n"
body += "// Shortcuts to the clblast namespace\n"
body += "using float2 = clblast::float2;\n"
body += "using double2 = clblast::double2;\n\n"
@@ -487,7 +500,7 @@ with open(filename, "w") as f:
f.write("\n\n")
# Loops over the routines
- for level in [1,2,3]:
+ for level in [1,2,3,4]:
for routine in routines[level-1]:
if routine.implemented: