From 6fb3c9fdf866105482447190a11619721e4c6b34 Mon Sep 17 00:00:00 2001 From: Gard Spreemann Date: Wed, 24 Oct 2018 13:08:33 +0200 Subject: Initial commit of cleaned-up code. --- scnn/__init__.py | 0 scnn/chebyshev.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ scnn/scnn.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+) create mode 100644 scnn/__init__.py create mode 100644 scnn/chebyshev.py create mode 100644 scnn/scnn.py diff --git a/scnn/__init__.py b/scnn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scnn/chebyshev.py b/scnn/chebyshev.py new file mode 100644 index 0000000..8b5e086 --- /dev/null +++ b/scnn/chebyshev.py @@ -0,0 +1,51 @@ +import torch +import scipy.sparse as sp +import scipy.sparse.linalg as spl +import numpy as np + +def normalize(L, half_interval = False): + assert(sp.isspmatrix(L)) + M = L.shape[0] + assert(M == L.shape[1]) + topeig = spl.eigsh(L, k=1, which="LM", return_eigenvectors = False)[0] + #print("Topeig = %f" %(topeig)) + + ret = L.copy() + if half_interval: + ret *= 1.0/topeig + else: + ret *= 2.0/topeig + ret.setdiag(ret.diagonal(0) - np.ones(M), 0) + + return ret + +def assemble(K, L, x): + (B, C_in, M) = x.shape + assert(L.shape[0] == M) + assert(L.shape[0] == L.shape[1]) + assert(K > 0) + + X = [] + for b in range(0, B): + X123 = [] + for c_in in range(0, C_in): + X23 = [] + X23.append(x[b, c_in, :].unsqueeze(1)) # Constant, k = 0 term. + + if K > 1: + X23.append(L.mm(X23[0])) + for k in range(2, K): + X23.append(2*(L.mm(X23[k-1])) - X23[k-2]) + + X23 = torch.cat(X23, 1) + assert(X23.shape == (M, K)) + X123.append(X23.unsqueeze(0)) + + X123 = torch.cat(X123, 0) + assert(X123.shape == (C_in, M, K)) + X.append(X123.unsqueeze(0)) + + X = torch.cat(X, 0) + assert(X.shape == (B, C_in, M, K)) + + return X diff --git a/scnn/scnn.py b/scnn/scnn.py new file mode 100644 index 0000000..49efe71 --- /dev/null +++ b/scnn/scnn.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import numpy as np +import scipy.sparse as sp + +import chebyshev + +def coo2tensor(A): + assert(sp.isspmatrix_coo(A)) + idxs = torch.LongTensor(np.vstack((A.row, A.col))) + vals = torch.FloatTensor(A.data) + return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False) + +class SimplicialConvolution(nn.Module): + def __init__(self, K, L, C_in, C_out, groups = 1): + assert(groups == 1, "Only groups = 1 is currently supported.") + super().__init__() + + assert(len(L.shape) == 2) + assert(L.shape[0] == L.shape[1]) + assert(C_in > 0) + assert(C_out > 0) + assert(K > 0) + + self.M = L.shape[0] + self.C_in = C_in + self.C_out = C_out + self.K = K + self.L = L # Don't modify afterwards! + + self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K))) + + def forward(self, x): + (B, C_in, M) = x.shape + assert(M == self.M) + assert(C_in == self.C_in) + + X = chebyshev.assemble(self.K, self.L, x) + y = torch.einsum("bimk,oik->bom", (X, self.theta)) + assert(y.shape == (B, self.C_out, M)) + + return y -- cgit v1.2.3