1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse as sp
import scnn.chebyshev
def coo2tensor(A):
    """Convert a SciPy COO sparse matrix/array into a torch sparse COO tensor.

    Args:
        A: SciPy sparse object in COO format (``coo_matrix`` or ``coo_array``).

    Returns:
        ``torch.sparse_coo_tensor`` with the same shape as ``A``, int64
        indices built from ``A.row``/``A.col`` and float32 values built from
        ``A.data``. ``requires_grad`` is ``False``. Duplicate (row, col)
        entries are passed through as-is (no coalescing is performed here).

    Raises:
        AssertionError: If ``A`` is not a SciPy sparse object in COO format.
    """
    # ``sp.isspmatrix_coo`` is False for ``coo_array`` on newer SciPy releases,
    # so check the sparse format directly to accept both matrix and array APIs.
    assert sp.issparse(A) and A.format == "coo"
    # Explicit dtypes (and copies) instead of the legacy typed constructors.
    idxs = torch.tensor(np.vstack((A.row, A.col)), dtype=torch.long)
    vals = torch.tensor(A.data, dtype=torch.float32)
    return torch.sparse_coo_tensor(idxs, vals, size=A.shape, requires_grad=False)
class SimplicialConvolution(nn.Module):
    """Learned convolution over M simplices defined by a square operator ``L``.

    The input is expanded into ``K`` terms per channel by
    ``scnn.chebyshev.assemble`` and then mixed with a learned weight tensor
    ``theta`` of shape ``(C_out, C_in, K)``.

    Args:
        K: Number of terms produced per input channel (must be > 0).
        L: 2-D square operator of shape ``(M, M)``. Presumably a (sparse)
            Laplacian tensor, e.g. one built with ``coo2tensor`` -- TODO
            confirm against callers.
        C_in: Number of input channels (must be > 0).
        C_out: Number of output channels (must be > 0).
        groups: Only ``1`` is supported.

    Raises:
        AssertionError: If any of the constraints above is violated.
    """

    def __init__(self, K, L, C_in, C_out, groups=1):
        # Plain ``assert cond, msg``: the parenthesised ``assert(cond, msg)``
        # asserts a non-empty tuple, which is always true and never fires.
        assert groups == 1, "Only groups = 1 is currently supported."
        super().__init__()
        assert len(L.shape) == 2
        assert L.shape[0] == L.shape[1]
        assert C_in > 0
        assert C_out > 0
        assert K > 0

        self.M = L.shape[0]
        self.C_in = C_in
        self.C_out = C_out
        self.K = K
        # Don't modify afterwards! Stored as a plain attribute (not a
        # parameter or buffer), so ``Module.to(device)`` does not move it.
        self.L = L
        # Weights indexed as (output channel, input channel, term).
        self.theta = nn.Parameter(torch.randn((self.C_out, self.C_in, self.K)))

    def forward(self, x):
        """Apply the convolution.

        Args:
            x: Input tensor of shape ``(B, C_in, M)``.

        Returns:
            Tensor of shape ``(B, C_out, M)``.
        """
        (B, C_in, M) = x.shape
        assert M == self.M
        assert C_in == self.C_in

        # ``import scnn.chebyshev`` binds only the name ``scnn``, so the
        # module must be reached through the package (a bare ``chebyshev``
        # raises NameError). The einsum below requires a (B, C_in, M, K) result.
        X = scnn.chebyshev.assemble(self.K, self.L, x)
        # Contract input channels (i) and terms (k) against theta.
        y = torch.einsum("bimk,oik->bom", X, self.theta)
        assert y.shape == (B, self.C_out, M)
        return y
|