summaryrefslogtreecommitdiff
path: root/scripts/generator/generator/routine.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r--scripts/generator/generator/routine.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 7321349d..3b5a6b76 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -205,7 +205,7 @@ class Routine:
def no_scalars(self):
"""Determines whether or not this routine has scalar arguments (alpha/beta)"""
- return self.scalars == [] or self.name in ["im2col", "convgemm"]
+ return self.scalars == [] or self.name in ["im2col", "col2im", "convgemm"]
def has_layout(self):
"""Determines whether the layout is an argument"""
@@ -226,12 +226,14 @@ class Routine:
"""Determines which buffers go first (between alpha and beta) and which ones go after"""
if self.level == "2b" or self.name == "had":
return ["x", "y"]
- return ["ap", "a", "b", "x", "im", "kernel"]
+ extra_buffer = "col" if self.name == "col2im" else "im"
+ return ["ap", "a", "b", "x", extra_buffer, "kernel"]
def buffers_second(self):
if self.level == "2b" or self.name == "had":
return ["z", "ap", "a", "b", "c"]
- return ["y", "c", "col", "result"]
+ extra_buffer = "im" if self.name == "col2im" else "col"
+ return ["y", "c", extra_buffer, "result"]
def buffer(self, name):
"""Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')"""