From 6a69bc65d6cf7942a0712e01a8200fa0f1d51595 Mon Sep 17 00:00:00 2001 From: Gard Spreemann Date: Thu, 17 Jan 2019 16:11:29 +0100 Subject: Fix. --- scnn/scnn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scnn/scnn.py b/scnn/scnn.py index f81ca19..a22aaeb 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -32,6 +32,7 @@ class SimplicialConvolution(nn.Module): def forward(self, x): (B, C_in, M) = x.shape + assert(M == self.M) assert(C_in == self.C_in) @@ -66,6 +67,7 @@ class Coboundary(nn.Module): (B, C_in, M) = x.shape assert(self.D.shape[1] == M) assert(C_in == self.C_in) + N = self.D.shape[0] y = torch.einsum("oi,nm,bim->bon", (self.theta, self.D, x)) assert(y.shape == (B, self.C_out, N)) -- cgit v1.2.3