summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspreemann@gmail.com>2019-01-17 19:07:04 +0100
committerGard Spreemann <gspreemann@gmail.com>2019-01-17 19:07:04 +0100
commit90bbc1585e637a6f98e41177d8e88ce575724968 (patch)
tree2244ca0bf8ffc43f4b98885c718579b950e48b46
parent6a69bc65d6cf7942a0712e01a8200fa0f1d51595 (diff)
This should be the correct way to do the (co)boundary module.
-rw-r--r--scnn/scnn.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py
index a22aaeb..ed9ea2d 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -69,7 +69,22 @@ class Coboundary(nn.Module):
assert(C_in == self.C_in)
N = self.D.shape[0]
- y = torch.einsum("oi,nm,bim->bon", (self.theta, self.D, x))
+ # This is essentially the equivalent of chebyshev.assemble for
+ # the convolutional modules.
+ X = []
+ for b in range(0, B):
+ X12 = []
+ for c_in in range(0, self.C_in):
+ X12.append(self.D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1
+ X12 = torch.cat(X12, 0)
+
+ assert(X12.shape == (self.C_in, N))
+ X.append(X12.unsqueeze(0))
+
+ X = torch.cat(X, 0)
+ assert(X.shape == (B, self.C_in, N))
+
+ y = torch.einsum("oi,bin->bon", (self.theta, X))
assert(y.shape == (B, self.C_out, N))
return y