diff options
Diffstat (limited to 'scripts/generator/generator.py')
-rw-r--r-- | scripts/generator/generator.py | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index 7bb66749..6726adda 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -68,6 +68,7 @@ ald_transa_m_k = "When `transpose_a == Transpose::kNo`, then `a_ld` must be at l ald_trans_n_k = "When `transpose == Transpose::kNo`, then `a_ld` must be at least `n`, otherwise `a_ld` must be at least `k`." ald_side_m_n = "When `side = Side::kLeft` then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `n`." bld_m = "The value of `b_ld` must be at least `m`." +bld_n = "The value of `b_ld` must be at least `n`." bld_transb_k_n = "When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`." bld_trans_n_k = "When `transpose == Transpose::kNo`, then `b_ld` must be at least `n`, otherwise `b_ld` must be at least `k`." cld_m = "The value of `c_ld` must be at least `m`." @@ -134,6 +135,9 @@ routines = [ Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]), Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]), Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []), +], +[ # Level X: extra routines (not part of BLAS) + Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]), ]] # ================================================================================================== @@ -148,6 +152,7 @@ def PrecisionToFullName(x): }[x] # ================================================================================================== + # Separators for the BLAS levels separators = [""" // ================================================================================================= @@ -160,8 +165,15 @@ separators = [""" """ // ================================================================================================= // BLAS level-3 (matrix-matrix) routines +// =================================================================================================""", +""" +// ================================================================================================= +// Extra non-BLAS routines (level-X) // ================================================================================================="""] +# Names of the level sub-folders +levelnames = ["1", "2", "3", "x"] + # Main header/footer for source files header = """ // ================================================================================================= @@ -373,7 +385,7 @@ files = [ path_clblast+"/test/wrapper_clblas.h", path_clblast+"/test/wrapper_cblas.h", ] -header_lines = [84, 71, 93, 22, 29, 41] +header_lines = [84, 74, 93, 22, 29, 41] footer_lines = [17, 71, 19, 14, 6, 6] # Checks whether the command-line arguments are valid; exists otherwise @@ -396,7 +408,8 @@ for i in xrange(0,len(files)): # Re-writes the body of the file with open(files[i], "w") as f: body = "" - for level in [1,2,3]: + levels = [1,2,3] if (i == 4 or i == 5) else [1,2,3,4] + for level in levels: body += separators[level-1]+"\n" if i == 0: body += clblast_h(routines[level-1]) @@ -417,14 +430,14 @@ for i in xrange(0,len(files)): # ================================================================================================== # Outputs all the correctness-test implementations -for level in [1,2,3]: +for level in [1,2,3,4]: for routine in routines[level-1]: if routine.has_tests: - filename = path_clblast+"/test/correctness/routines/level"+str(level)+"/x"+routine.name+".cc" + filename = path_clblast+"/test/correctness/routines/level"+levelnames[level-1]+"/x"+routine.name+".cc" with open(filename, "w") as f: body = "" body += "#include \"correctness/testblas.h\"\n" - body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n" + body += "#include \"routines/level"+levelnames[level-1]+"/x"+routine.name+".h\"\n\n" body += "// Shortcuts to the clblast namespace\n" body += "using float2 = clblast::float2;\n" body += "using double2 = clblast::double2;\n\n" @@ -443,14 +456,14 @@ for level in [1,2,3]: f.write(footer) # Outputs all the performance-test implementations -for level in [1,2,3]: +for level in [1,2,3,4]: for routine in routines[level-1]: if routine.has_tests: - filename = path_clblast+"/test/performance/routines/level"+str(level)+"/x"+routine.name+".cc" + filename = path_clblast+"/test/performance/routines/level"+levelnames[level-1]+"/x"+routine.name+".cc" with open(filename, "w") as f: body = "" body += "#include \"performance/client.h\"\n" - body += "#include \"routines/level"+str(level)+"/x"+routine.name+".h\"\n\n" + body += "#include \"routines/level"+levelnames[level-1]+"/x"+routine.name+".h\"\n\n" body += "// Shortcuts to the clblast namespace\n" body += "using float2 = clblast::float2;\n" body += "using double2 = clblast::double2;\n\n" @@ -487,7 +500,7 @@ with open(filename, "w") as f: f.write("\n\n") # Loops over the routines - for level in [1,2,3]: + for level in [1,2,3,4]: for routine in routines[level-1]: if routine.implemented: |