summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/routine.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r--scripts/generator/generator/routine.py50
1 files changed, 33 insertions, 17 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index dd3c2ecb..f7c2a701 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -12,12 +12,12 @@ import generator.convert as convert
class Routine:
"""Class holding routine-specific information (e.g. name, which arguments, which precisions)"""
- def __init__(self, implemented, has_tests, batched, temp_buffer, level, name, template, flavours, sizes, options,
+ def __init__(self, implemented, has_tests, batched_strided, temp_buffer, level, name, template, flavours, sizes, options,
inputs, outputs, buffer_sizes, scalars, scratch,
description, details, requirements):
self.implemented = implemented
self.has_tests = has_tests
- self.batched = batched
+ self.batched = batched_strided
self.temp_buffer = temp_buffer
self.level = level
self.name = name
@@ -35,38 +35,42 @@ class Routine:
self.requirements = requirements
def lowercase_name(self):
- postfix = "batched" if self.batched else ""
+ postfix = "strided" if self.batched == 2 else ""
+ postfix += "batched" if self.batched != 0 else ""
return self.name + postfix
def plain_name(self):
- postfix = "Batched" if self.batched else ""
+ postfix = "Strided" if self.batched == 2 else ""
+ postfix += "Batched" if self.batched != 0 else ""
return self.name + postfix
def capitalized_name(self):
- postfix = "Batched" if self.batched else ""
+ postfix = "Strided" if self.batched == 2 else ""
+ postfix += "Batched" if self.batched != 0 else ""
return self.name.capitalize() + postfix
def upper_name(self):
- postfix = "BATCHED" if self.batched else ""
+ postfix = "STRIDED" if self.batched == 2 else ""
+ postfix += "BATCHED" if self.batched != 0 else ""
return self.name.upper() + postfix
def b_star(self):
- return "*" if self.batched else ""
+ return "*" if self.batched == 1 else ""
def b_s(self):
- return "s" if self.batched else ""
+ return "s" if self.batched == 1 else ""
def batch_count_def(self):
- return ["const size_t batch_count"] if self.batched else []
+ return ["const size_t batch_count"] if self.batched != 0 else []
def batch_count_list(self):
- return ["batch_count"] if self.batched else []
+ return ["batch_count"] if self.batched != 0 else []
def batch_count_type(self):
- return ["const size_t"] if self.batched else []
+ return ["const size_t"] if self.batched != 0 else []
def batch_count_doc(self):
- return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched else []
+ return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched != 0 else []
def batched_transform_to_cpp(self):
result = []
@@ -230,6 +234,8 @@ class Routine:
a = [name + "_buffer"]
b = [name + "_offset" + self.b_s()]
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
+ if self.batched == 2:
+ c += [name + "_stride"]
return [", ".join(a + b + c)]
return []
@@ -239,6 +245,8 @@ class Routine:
a = [name + "_buffer_bis"]
b = [name + "_offset"]
c = [name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ if self.batched == 2:
+ c += [name + "_stride"]
return [", ".join(a + b + c)]
return []
@@ -258,6 +266,8 @@ class Routine:
a = [prefix + "cl_mem " + name + "_buffer"]
b = ["const size_t " + self.b_star() + name + "_offset" + self.b_s()]
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
+ if self.batched == 2:
+ c += ["const size_t " + name + "_stride"]
return [", ".join(a + b + c)]
return []
@@ -307,8 +317,10 @@ class Routine:
if name in self.inputs or name in self.outputs:
buffer_type = "unsigned int" if (name in self.index_buffers()) else self.template.buffer_type
a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"]
- b = [name + "_offsets_cpp"] if self.batched else [name + "_offset"]
+ b = [name + "_offsets_cpp"] if self.batched == 1 else [name + "_offset"]
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
+ if self.batched == 2:
+ c += [name + "_stride"]
return [", ".join(a + b + c)]
return []
@@ -375,6 +387,8 @@ class Routine:
a = [prefix + "cl_mem"]
b = ["const size_t" + self.b_star()]
c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else []
+ if self.batched == 2:
+ c += ["const size_t"]
return [", ".join(a + b + c)]
return []
@@ -391,13 +405,15 @@ class Routine:
if name not in self.buffers_without_ld_inc():
c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " +
inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."]
+ if self.batched == 2:
+ c += ["`const size_t " + name + "_stride`: The (fixed) stride between two batches of the " + name.upper() + " matrix."]
return a + b + c
return []
def scalar(self, name):
"""Retrieves the name of a scalar (alpha/beta)"""
if name in self.scalars:
- if self.batched:
+ if self.batched == 1:
return [name + "s_cpp"]
return [name]
return []
@@ -418,11 +434,11 @@ class Routine:
"""Retrieves the use of a scalar (alpha/beta)"""
if name in self.scalars:
if name == "alpha":
- if self.batched:
+ if self.batched == 1:
return ["alphas_cpp.data()"]
return [flavour.use_alpha()]
elif name == "beta":
- if self.batched:
+ if self.batched == 1:
return ["betas_cpp.data()"]
return [flavour.use_beta()]
return [name]
@@ -866,7 +882,7 @@ class Routine:
if self.name in self.routines_scalar_no_return():
routine_name += "_sub"
indent += " "
- if self.batched:
+ if self.batched != 0:
routine_name += "batched"
result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "("
result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")"