summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--scnn/scnn.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py
index f81ca19..a22aaeb 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -32,6 +32,7 @@ class SimplicialConvolution(nn.Module):
def forward(self, x):
(B, C_in, M) = x.shape
+
assert(M == self.M)
assert(C_in == self.C_in)
@@ -66,6 +67,7 @@ class Coboundary(nn.Module):
(B, C_in, M) = x.shape
assert(self.D.shape[1] == M)
assert(C_in == self.C_in)
+ N = self.D.shape[0]
y = torch.einsum("oi,nm,bim->bon", (self.theta, self.D, x))
assert(y.shape == (B, self.C_out, N))