diff options
author | Gard Spreemann <gspreemann@gmail.com> | 2019-01-17 19:07:04 +0100 |
---|---|---|
committer | Gard Spreemann <gspreemann@gmail.com> | 2019-01-17 19:07:04 +0100 |
commit | 90bbc1585e637a6f98e41177d8e88ce575724968 (patch) | |
tree | 2244ca0bf8ffc43f4b98885c718579b950e48b46 | |
parent | 6a69bc65d6cf7942a0712e01a8200fa0f1d51595 (diff) |
This should be the correct way to do the (co)boundary module.
-rw-r--r-- | scnn/scnn.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py index a22aaeb..ed9ea2d 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -69,7 +69,22 @@ class Coboundary(nn.Module): assert(C_in == self.C_in) N = self.D.shape[0] - y = torch.einsum("oi,nm,bim->bon", (self.theta, self.D, x)) + # This is essentially the equivalent of chebyshev.assemble for + # the convolutional modules. + X = [] + for b in range(0, B): + X12 = [] + for c_in in range(0, self.C_in): + X12.append(self.D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1 + X12 = torch.cat(X12, 0) + + assert(X12.shape == (self.C_in, N)) + X.append(X12.unsqueeze(0)) + + X = torch.cat(X, 0) + assert(X.shape == (B, self.C_in, N)) + + y = torch.einsum("oi,bin->bon", (self.theta, X)) assert(y.shape == (B, self.C_out, N)) return y |