summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--scnn/scnn.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py
index a22aaeb..ed9ea2d 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -69,7 +69,22 @@ class Coboundary(nn.Module):
assert(C_in == self.C_in)
N = self.D.shape[0]
- y = torch.einsum("oi,nm,bim->bon", (self.theta, self.D, x))
+ # This is essentially the equivalent of chebyshev.assemble for
+ # the convolutional modules.
+ X = []
+ for b in range(0, B):
+ X12 = []
+ for c_in in range(0, self.C_in):
+ X12.append(self.D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1
+ X12 = torch.cat(X12, 0)
+
+ assert(X12.shape == (self.C_in, N))
+ X.append(X12.unsqueeze(0))
+
+ X = torch.cat(X, 0)
+ assert(X.shape == (B, self.C_in, N))
+
+ y = torch.einsum("oi,bin->bon", (self.theta, X))
assert(y.shape == (B, self.C_out, N))
return y