diff options
Diffstat (limited to 'scripts/generator/generator/routine.py')
-rw-r--r-- | scripts/generator/generator/routine.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 7321349d..3b5a6b76 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -205,7 +205,7 @@ class Routine: def no_scalars(self): """Determines whether or not this routine has scalar arguments (alpha/beta)""" - return self.scalars == [] or self.name in ["im2col", "convgemm"] + return self.scalars == [] or self.name in ["im2col", "col2im", "convgemm"] def has_layout(self): """Determines whether the layout is an argument""" @@ -226,12 +226,14 @@ class Routine: """Determines which buffers go first (between alpha and beta) and which ones go after""" if self.level == "2b" or self.name == "had": return ["x", "y"] - return ["ap", "a", "b", "x", "im", "kernel"] + extra_buffer = "col" if self.name == "col2im" else "im" + return ["ap", "a", "b", "x", extra_buffer, "kernel"] def buffers_second(self): if self.level == "2b" or self.name == "had": return ["z", "ap", "a", "b", "c"] - return ["y", "c", "col", "result"] + extra_buffer = "im" if self.name == "col2im" else "col" + return ["y", "c", extra_buffer, "result"] def buffer(self, name): """Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')""" |