summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspreemann@gmail.com>2018-10-24 13:08:33 +0200
committerGard Spreemann <gspreemann@gmail.com>2018-10-24 13:08:33 +0200
commit6fb3c9fdf866105482447190a11619721e4c6b34 (patch)
treeb081b6d24445aceae135edb5381d9502b53a6303
Initial commit of cleaned-up code.
-rw-r--r--scnn/__init__.py0
-rw-r--r--scnn/chebyshev.py51
-rw-r--r--scnn/scnn.py42
3 files changed, 93 insertions, 0 deletions
diff --git a/scnn/__init__.py b/scnn/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scnn/__init__.py
diff --git a/scnn/chebyshev.py b/scnn/chebyshev.py
new file mode 100644
index 0000000..8b5e086
--- /dev/null
+++ b/scnn/chebyshev.py
@@ -0,0 +1,51 @@
+import torch
+import scipy.sparse as sp
+import scipy.sparse.linalg as spl
+import numpy as np
+
+def normalize(L, half_interval = False):
+ assert(sp.isspmatrix(L))
+ M = L.shape[0]
+ assert(M == L.shape[1])
+ topeig = spl.eigsh(L, k=1, which="LM", return_eigenvectors = False)[0]
+ #print("Topeig = %f" %(topeig))
+
+ ret = L.copy()
+ if half_interval:
+ ret *= 1.0/topeig
+ else:
+ ret *= 2.0/topeig
+ ret.setdiag(ret.diagonal(0) - np.ones(M), 0)
+
+ return ret
+
+def assemble(K, L, x):
+ (B, C_in, M) = x.shape
+ assert(L.shape[0] == M)
+ assert(L.shape[0] == L.shape[1])
+ assert(K > 0)
+
+ X = []
+ for b in range(0, B):
+ X123 = []
+ for c_in in range(0, C_in):
+ X23 = []
+ X23.append(x[b, c_in, :].unsqueeze(1)) # Constant, k = 0 term.
+
+ if K > 1:
+ X23.append(L.mm(X23[0]))
+ for k in range(2, K):
+ X23.append(2*(L.mm(X23[k-1])) - X23[k-2])
+
+ X23 = torch.cat(X23, 1)
+ assert(X23.shape == (M, K))
+ X123.append(X23.unsqueeze(0))
+
+ X123 = torch.cat(X123, 0)
+ assert(X123.shape == (C_in, M, K))
+ X.append(X123.unsqueeze(0))
+
+ X = torch.cat(X, 0)
+ assert(X.shape == (B, C_in, M, K))
+
+ return X
diff --git a/scnn/scnn.py b/scnn/scnn.py
new file mode 100644
index 0000000..49efe71
--- /dev/null
+++ b/scnn/scnn.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import scipy.sparse as sp
+
+import chebyshev
+
+def coo2tensor(A):
+ assert(sp.isspmatrix_coo(A))
+ idxs = torch.LongTensor(np.vstack((A.row, A.col)))
+ vals = torch.FloatTensor(A.data)
+ return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False)
+
+class SimplicialConvolution(nn.Module):
+ def __init__(self, K, L, C_in, C_out, groups = 1):
+ assert(groups == 1, "Only groups = 1 is currently supported.")
+ super().__init__()
+
+ assert(len(L.shape) == 2)
+ assert(L.shape[0] == L.shape[1])
+ assert(C_in > 0)
+ assert(C_out > 0)
+ assert(K > 0)
+
+ self.M = L.shape[0]
+ self.C_in = C_in
+ self.C_out = C_out
+ self.K = K
+ self.L = L # Don't modify afterwards!
+
+ self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K)))
+
+ def forward(self, x):
+ (B, C_in, M) = x.shape
+ assert(M == self.M)
+ assert(C_in == self.C_in)
+
+ X = chebyshev.assemble(self.K, self.L, x)
+ y = torch.einsum("bimk,oik->bom", (X, self.theta))
+ assert(y.shape == (B, self.C_out, M))
+
+ return y