summaryrefslogtreecommitdiff
path: root/scnn/scnn.py
blob: a22aaeba59e887ef131aae48b3d559bc1035af17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse as sp

import scnn.chebyshev

def coo2tensor(A):
    """Convert a SciPy sparse matrix/array to a torch sparse COO tensor.

    Parameters
    ----------
    A : scipy.sparse matrix or array
        Any 2-D SciPy sparse container (COO, CSR, CSC, ...; both the
        ``spmatrix`` and the newer ``sparray`` flavours). Non-COO inputs are
        converted with ``A.tocoo()``.

    Returns
    -------
    torch.Tensor
        Sparse COO tensor with the same shape as ``A``, int64 indices,
        float32 values and ``requires_grad=False``. Entries are not
        coalesced here. Index and value data are copied, so later changes
        to ``A`` do not affect the result.

    Raises
    ------
    TypeError
        If ``A`` is not a SciPy sparse container.
    """
    # Explicit check instead of `assert` so validation survives `python -O`.
    # `issparse` (unlike `isspmatrix_coo`) also accepts `coo_array` & co.
    if not sp.issparse(A):
        raise TypeError(
            f"expected a scipy.sparse matrix or array, got {type(A).__name__}")
    A = A.tocoo()

    # Row 0 holds row indices, row 1 column indices: the layout torch expects.
    idxs = torch.tensor(np.vstack((A.row, A.col)), dtype=torch.long)
    vals = torch.tensor(A.data, dtype=torch.float32)
    return torch.sparse_coo_tensor(idxs, vals, size=A.shape, requires_grad=False)

class SimplicialConvolution(nn.Module):
    """Learned K-tap filter over a fixed square operator ``L``.

    ``scnn.chebyshev.assemble(K, L, x)`` expands the input into K taps
    (presumably Chebyshev polynomials of ``L``; see that module). The taps
    are then contracted with the learned weights ``theta`` of shape
    ``(C_out, C_in, K)``.

    Parameters
    ----------
    K : int
        Number of filter taps; must be > 0.
    L : square 2-D matrix of shape (M, M)
        Operator handed to ``scnn.chebyshev.assemble``; treat as read-only.
        If it is a ``torch.Tensor`` it is registered as a non-persistent
        buffer. It then follows ``.to()``/``.cuda()``/``.double()``
        together with the parameters, while ``state_dict()`` keeps the same
        keys as before.
    C_in, C_out : int
        Number of input / output channels; both must be > 0.
    groups : int
        Only ``1`` is supported.

    Raises
    ------
    NotImplementedError
        If ``groups != 1``.
    ValueError
        If ``L`` is not a square 2-D matrix, or ``K``/``C_in``/``C_out``
        is not positive.
    """

    def __init__(self, K, L, C_in, C_out, groups = 1):
        if groups != 1:
            raise NotImplementedError("Only groups = 1 is currently supported.")
        super().__init__()

        # Explicit checks instead of `assert` so they survive `python -O`.
        if len(L.shape) != 2 or L.shape[0] != L.shape[1]:
            raise ValueError(
                f"L must be a square 2-D matrix, got shape {tuple(L.shape)}")
        if K <= 0 or C_in <= 0 or C_out <= 0:
            raise ValueError(
                f"K, C_in and C_out must be positive, "
                f"got K={K}, C_in={C_in}, C_out={C_out}")

        self.M = L.shape[0]
        self.C_in = C_in
        self.C_out = C_out
        self.K = K
        if isinstance(L, torch.Tensor):
            # A buffer moves with the module; persistent=False keeps it out
            # of state_dict so existing checkpoints still load strictly.
            self.register_buffer("L", L, persistent=False)
        else:
            self.L = L  # Don't modify afterwards!

        self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K)))

    def forward(self, x):
        """Filter ``x`` of shape (B, C_in, M) into shape (B, C_out, M).

        Raises
        ------
        ValueError
            If the channel count or the last dimension of ``x`` does not
            match the layer.
        """
        (B, C_in, M) = x.shape

        if M != self.M:
            raise ValueError(f"expected last dimension {self.M}, got {M}")
        if C_in != self.C_in:
            raise ValueError(f"expected {self.C_in} input channels, got {C_in}")

        # X is indexed (batch b, in-channel i, element m, tap k), as the
        # einsum below requires.
        X = scnn.chebyshev.assemble(self.K, self.L, x)
        y = torch.einsum("bimk,oik->bom", (X, self.theta))
        assert y.shape == (B, self.C_out, M)  # internal invariant

        return y

# This class does not yet implement composing the coboundary with
# powers of the Laplacian (applied before or after it). The same effect
# can be obtained by stacking more layers, so it is kept simple for
# now.
#
# Note: this class can also be used for adjoints of coboundaries; just
# pass a transposed D.
class Coboundary(nn.Module):
    """Learned channel mixing followed by a fixed linear map ``D``.

    Computes ``y[b, o, n] = sum_{i, m} theta[o, i] * D[n, m] * x[b, i, m]``.

    Parameters
    ----------
    D : 2-D matrix of shape (N, M)
        Fixed operator; treat as read-only. As a ``torch.Tensor`` it may be
        dense or sparse COO. It is registered as a non-persistent buffer,
        so it follows ``.to()``/``.cuda()`` while ``state_dict()`` keeps
        the same keys.
    C_in, C_out : int
        Number of input / output channels; both must be > 0.

    Raises
    ------
    ValueError
        If ``D`` is not 2-D or a channel count is not positive.
    """

    def __init__(self, D, C_in, C_out):
        super().__init__()

        # Explicit checks instead of `assert` so they survive `python -O`.
        if len(D.shape) != 2:
            raise ValueError(f"D must be 2-D, got shape {tuple(D.shape)}")
        if C_in <= 0 or C_out <= 0:
            raise ValueError(
                f"C_in and C_out must be positive, got C_in={C_in}, C_out={C_out}")

        self.C_in = C_in
        self.C_out = C_out
        if isinstance(D, torch.Tensor):
            # A buffer moves with the module; persistent=False keeps it out
            # of state_dict so existing checkpoints still load strictly.
            self.register_buffer("D", D, persistent=False)
        else:
            self.D = D  # Don't modify.

        self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in)))

    def forward(self, x):
        """Map ``x`` of shape (B, C_in, M) to shape (B, C_out, N).

        Raises
        ------
        ValueError
            If ``x`` does not match ``D``'s column count or ``C_in``.
        """
        (B, C_in, M) = x.shape
        if self.D.shape[1] != M:
            raise ValueError(f"expected last dimension {self.D.shape[1]}, got {M}")
        if C_in != self.C_in:
            raise ValueError(f"expected {self.C_in} input channels, got {C_in}")
        N = self.D.shape[0]

        # Mix channels first: (B, C_in, M) -> (B, C_out, M). A single
        # "oi,nm,bim->bon" einsum contracted left to right would first build
        # the (C_out, C_in, N, M) outer product of theta and D.
        z = torch.einsum("oi,bim->bom", self.theta, x)
        # Apply D as one 2-D matmul so that sparse D works too:
        # (N, M) @ (M, B*C_out) -> (N, B*C_out).
        flat = z.reshape(B * self.C_out, M).t()
        y = (self.D @ flat).t().reshape(B, self.C_out, N)
        assert y.shape == (B, self.C_out, N)  # internal invariant

        return y