author    Gard Spreemann <gspr@nonempty.org>  2019-05-04 15:24:37 +0200
committer Gard Spreemann <gspr@nonempty.org>  2019-05-04 15:24:37 +0200
commit    7a3e06da1f629d1186f1f66dc3055f95a121ae40 (patch)
tree      014b8f291e281922360b37077774297e4b265cd1
parent    ef2e9d1567906507a04c181fcbbb19bda3ca5639 (diff)
Make the Laplacian changeable with each forward call.
-rw-r--r--  scnn/scnn.py  33
1 file changed, 17 insertions, 16 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py
index 91f9706..f46cc56 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -12,21 +12,17 @@ def coo2tensor(A):
return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False)
class SimplicialConvolution(nn.Module):
- def __init__(self, K, L, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1):
+ def __init__(self, K, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1):
assert groups == 1, "Only groups = 1 is currently supported."
super().__init__()
- assert(len(L.shape) == 2)
- assert(L.shape[0] == L.shape[1])
assert(C_in > 0)
assert(C_out > 0)
assert(K > 0)
- self.M = L.shape[0]
self.C_in = C_in
self.C_out = C_out
self.K = K
- self.L = L # Don't modify afterwards!
self.enable_bias = enable_bias
self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K)))
@@ -35,13 +31,16 @@ class SimplicialConvolution(nn.Module):
else:
self.bias = 0.0
- def forward(self, x):
+ def forward(self, L, x):
+ assert(len(L.shape) == 2)
+ assert(L.shape[0] == L.shape[1])
+
(B, C_in, M) = x.shape
-
- assert(M == self.M)
+
+ assert(M == L.shape[0])
assert(C_in == self.C_in)
- X = scnn.chebyshev.assemble(self.K, self.L, x)
+ X = scnn.chebyshev.assemble(self.K, L, x)
y = torch.einsum("bimk,oik->bom", (X, self.theta))
assert(y.shape == (B, self.C_out, M))
@@ -55,16 +54,14 @@ class SimplicialConvolution(nn.Module):
# Note: You can use this for adjoints of coboundaries too. Just feed
# a transposed D.
class Coboundary(nn.Module):
- def __init__(self, D, C_in, C_out, enable_bias = True, variance = 1.0):
+ def __init__(self, C_in, C_out, enable_bias = True, variance = 1.0):
super().__init__()
- assert(len(D.shape) == 2)
assert(C_in > 0)
assert(C_out > 0)
self.C_in = C_in
self.C_out = C_out
- self.D = D # Don't modify.
self.enable_bias = enable_bias
self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in)))
@@ -73,11 +70,15 @@ class Coboundary(nn.Module):
else:
self.bias = 0.0
- def forward(self, x):
+ def forward(self, D, x):
+ assert(len(D.shape) == 2)
+
(B, C_in, M) = x.shape
- assert(self.D.shape[1] == M)
+
+ assert(D.shape[1] == M)
assert(C_in == self.C_in)
- N = self.D.shape[0]
+
+ N = D.shape[0]
# This is essentially the equivalent of chebyshev.assemble for
# the convolutional modules.
@@ -85,7 +86,7 @@ class Coboundary(nn.Module):
for b in range(0, B):
X12 = []
for c_in in range(0, self.C_in):
- X12.append(self.D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1
+ X12.append(D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1
X12 = torch.cat(X12, 0)
assert(X12.shape == (self.C_in, N))
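
Usage sketch (illustrative, not part of the commit): after this change the
Laplacian L and the coboundary matrix D are arguments of forward() rather than
of the constructors, so the same modules can be reused with a different (and
differently sized) operator on each call. The sizes, the random sparse
matrices, and the assumption that coo2tensor accepts a SciPy COO matrix are
mine, not the author's.

    import scipy.sparse
    import torch

    from scnn.scnn import SimplicialConvolution, Coboundary, coo2tensor

    B, C_in, C_out, M, N, K = 1, 4, 8, 100, 60, 3  # illustrative sizes only

    # Stand-ins for a simplicial Laplacian (M x M) and a coboundary map (N x M).
    L = coo2tensor(scipy.sparse.random(M, M, density=0.05, format="coo"))
    D = coo2tensor(scipy.sparse.random(N, M, density=0.05, format="coo"))

    conv = SimplicialConvolution(K, C_in, C_out)  # no Laplacian at construction time
    cob = Coboundary(C_in, C_out)                 # no coboundary at construction time

    x = torch.randn(B, C_in, M)
    y = conv(L, x)  # Laplacian supplied per call; y should have shape (B, C_out, M)
    z = cob(D, x)   # coboundary supplied per call; z should have shape (B, C_out, N)

    # A different L or D can be passed on the next forward call without
    # rebuilding the modules, which appears to be the point of this commit.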