summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspreemann@gmail.com>2019-01-15 19:51:12 +0100
committerGard Spreemann <gspreemann@gmail.com>2019-01-15 19:51:12 +0100
commit7ef7516343db3678391939399f065565a6a4dcc5 (patch)
treed3543865608822fb3d9675d91808075529b65dcb
parent4a3b1f8951388086b2c28d9a17426ce4085480cd (diff)
Typo fixes.
-rw-r--r--scnn/scnn.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py
index 2eae3f3..f81ca19 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -40,3 +40,34 @@ class SimplicialConvolution(nn.Module):
assert(y.shape == (B, self.C_out, M))
return y
+
+# This class does not yet implement the
+# Laplacian-power-pre/post-composed with the coboundary. It can be
+# simulated by just adding more layers anyway, so keeping it simple
+# for now.
+#
+# Note: You can use this for a adjoints of coboundaries too. Just feed
+# a transposed D.
+class Coboundary(nn.Module):
+ def __init__(self, D, C_in, C_out):
+ super().__init__()
+
+ assert(len(D.shape) == 2)
+ assert(C_in > 0)
+ assert(C_out > 0)
+
+ self.C_in = C_in
+ self.C_out = C_out
+ self.D = D # Don't modify.
+
+ self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in)))
+
+ def forward(self, x):
+ (B, C_in, M) = x.shape
+ assert(self.D.shape[1] == M)
+ assert(C_in == self.C_in)
+
+ y = torch.einsum("oi,nm,bim->bon", (self.theta, self.D, x))
+ assert(y.shape == (B, self.C_out, N))
+
+ return y