diff options
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r-- | scripts/generator/generator/routine.py | 50 |
1 files changed, 33 insertions, 17 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index dd3c2ecb..f7c2a701 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -12,12 +12,12 @@ import generator.convert as convert class Routine: """Class holding routine-specific information (e.g. name, which arguments, which precisions)""" - def __init__(self, implemented, has_tests, batched, temp_buffer, level, name, template, flavours, sizes, options, + def __init__(self, implemented, has_tests, batched_strided, temp_buffer, level, name, template, flavours, sizes, options, inputs, outputs, buffer_sizes, scalars, scratch, description, details, requirements): self.implemented = implemented self.has_tests = has_tests - self.batched = batched + self.batched = batched_strided self.temp_buffer = temp_buffer self.level = level self.name = name @@ -35,38 +35,42 @@ class Routine: self.requirements = requirements def lowercase_name(self): - postfix = "batched" if self.batched else "" + postfix = "strided" if self.batched == 2 else "" + postfix += "batched" if self.batched != 0 else "" return self.name + postfix def plain_name(self): - postfix = "Batched" if self.batched else "" + postfix = "Strided" if self.batched == 2 else "" + postfix += "Batched" if self.batched != 0 else "" return self.name + postfix def capitalized_name(self): - postfix = "Batched" if self.batched else "" + postfix = "Strided" if self.batched == 2 else "" + postfix += "Batched" if self.batched != 0 else "" return self.name.capitalize() + postfix def upper_name(self): - postfix = "BATCHED" if self.batched else "" + postfix = "STRIDED" if self.batched == 2 else "" + postfix += "BATCHED" if self.batched != 0 else "" return self.name.upper() + postfix def b_star(self): - return "*" if self.batched else "" + return "*" if self.batched == 1 else "" def b_s(self): - return "s" if self.batched else "" + return "s" if self.batched == 1 else "" def batch_count_def(self): - return ["const size_t batch_count"] if self.batched else [] + return ["const size_t batch_count"] if self.batched != 0 else [] def batch_count_list(self): - return ["batch_count"] if self.batched else [] + return ["batch_count"] if self.batched != 0 else [] def batch_count_type(self): - return ["const size_t"] if self.batched else [] + return ["const size_t"] if self.batched != 0 else [] def batch_count_doc(self): - return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched else [] + return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched != 0 else [] def batched_transform_to_cpp(self): result = [] @@ -230,6 +234,8 @@ class Routine: a = [name + "_buffer"] b = [name + "_offset" + self.b_s()] c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else [] + if self.batched == 2: + c += [name + "_stride"] return [", ".join(a + b + c)] return [] @@ -239,6 +245,8 @@ class Routine: a = [name + "_buffer_bis"] b = [name + "_offset"] c = [name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else [] + if self.batched == 2: + c += [name + "_stride"] return [", ".join(a + b + c)] return [] @@ -258,6 +266,8 @@ class Routine: a = [prefix + "cl_mem " + name + "_buffer"] b = ["const size_t " + self.b_star() + name + "_offset" + self.b_s()] c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else [] + if self.batched == 2: + c += ["const size_t " + name + "_stride"] return [", ".join(a + b + c)] return [] @@ -307,8 +317,10 @@ class Routine: if name in self.inputs or name in self.outputs: buffer_type = "unsigned int" if (name in self.index_buffers()) else self.template.buffer_type a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"] - b = [name + "_offsets_cpp"] if self.batched else [name + "_offset"] + b = [name + "_offsets_cpp"] if self.batched == 1 else [name + "_offset"] c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else [] + if self.batched == 2: + c += [name + "_stride"] return [", ".join(a + b + c)] return [] @@ -375,6 +387,8 @@ class Routine: a = [prefix + "cl_mem"] b = ["const size_t" + self.b_star()] c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else [] + if self.batched == 2: + c += ["const size_t"] return [", ".join(a + b + c)] return [] @@ -391,13 +405,15 @@ class Routine: if name not in self.buffers_without_ld_inc(): c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " + inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."] + if self.batched == 2: + c += ["`const size_t " + name + "_stride`: The (fixed) stride between two batches of the " + name.upper() + " matrix."] return a + b + c return [] def scalar(self, name): """Retrieves the name of a scalar (alpha/beta)""" if name in self.scalars: - if self.batched: + if self.batched == 1: return [name + "s_cpp"] return [name] return [] @@ -418,11 +434,11 @@ class Routine: """Retrieves the use of a scalar (alpha/beta)""" if name in self.scalars: if name == "alpha": - if self.batched: + if self.batched == 1: return ["alphas_cpp.data()"] return [flavour.use_alpha()] elif name == "beta": - if self.batched: + if self.batched == 1: return ["betas_cpp.data()"] return [flavour.use_beta()] return [name] @@ -866,7 +882,7 @@ class Routine: if self.name in self.routines_scalar_no_return(): routine_name += "_sub" indent += " " - if self.batched: + if self.batched != 0: routine_name += "batched" result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "(" result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")" |