From 7a3e06da1f629d1186f1f66dc3055f95a121ae40 Mon Sep 17 00:00:00 2001 From: Gard Spreemann Date: Sat, 4 May 2019 15:24:37 +0200 Subject: Make the Laplacian changeable with each forward call. --- scnn/scnn.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/scnn/scnn.py b/scnn/scnn.py index 91f9706..f46cc56 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -12,21 +12,17 @@ def coo2tensor(A): return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False) class SimplicialConvolution(nn.Module): - def __init__(self, K, L, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1): + def __init__(self, K, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1): assert groups == 1, "Only groups = 1 is currently supported." super().__init__() - assert(len(L.shape) == 2) - assert(L.shape[0] == L.shape[1]) assert(C_in > 0) assert(C_out > 0) assert(K > 0) - self.M = L.shape[0] self.C_in = C_in self.C_out = C_out self.K = K - self.L = L # Don't modify afterwards! self.enable_bias = enable_bias self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K))) @@ -35,13 +31,16 @@ class SimplicialConvolution(nn.Module): else: self.bias = 0.0 - def forward(self, x): + def forward(self, L, x): + assert(len(L.shape) == 2) + assert(L.shape[0] == L.shape[1]) + (B, C_in, M) = x.shape - - assert(M == self.M) + + assert(M == L.shape[0]) assert(C_in == self.C_in) - X = scnn.chebyshev.assemble(self.K, self.L, x) + X = scnn.chebyshev.assemble(self.K, L, x) y = torch.einsum("bimk,oik->bom", (X, self.theta)) assert(y.shape == (B, self.C_out, M)) @@ -55,16 +54,14 @@ class SimplicialConvolution(nn.Module): # Note: You can use this for a adjoints of coboundaries too. Just feed # a transposed D. class Coboundary(nn.Module): - def __init__(self, D, C_in, C_out, enable_bias = True, variance = 1.0): + def __init__(self, C_in, C_out, enable_bias = True, variance = 1.0): super().__init__() - assert(len(D.shape) == 2) assert(C_in > 0) assert(C_out > 0) self.C_in = C_in self.C_out = C_out - self.D = D # Don't modify. self.enable_bias = enable_bias self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in))) @@ -73,11 +70,15 @@ class Coboundary(nn.Module): else: self.bias = 0.0 - def forward(self, x): + def forward(self, D, x): + assert(len(D.shape) == 2) + (B, C_in, M) = x.shape - assert(self.D.shape[1] == M) + + assert(D.shape[1] == M) assert(C_in == self.C_in) - N = self.D.shape[0] + + N = D.shape[0] # This is essentially the equivalent of chebyshev.assemble for # the convolutional modules. @@ -85,7 +86,7 @@ class Coboundary(nn.Module): for b in range(0, B): X12 = [] for c_in in range(0, self.C_in): - X12.append(self.D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1 + X12.append(D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1)) # D.mm(x[b, c_in, :]) has shape Nx1 X12 = torch.cat(X12, 0) assert(X12.shape == (self.C_in, N)) -- cgit v1.2.3