# scnn/scnn.py
import torch
import torch.nn as nn
import numpy as np
import scipy.sparse as sp

import scnn.chebyshev

# Convert a scipy sparse matrix in COO format into a torch sparse COO tensor.
def coo2tensor(A):
    assert(sp.isspmatrix_coo(A))
    idxs = torch.LongTensor(np.vstack((A.row, A.col)))
    vals = torch.FloatTensor(A.data)
    return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False)
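
# Minimal usage sketch (illustrative only): build a toy Laplacian with scipy
# and convert it with coo2tensor; the resulting sparse tensor is what the
# forward() methods below expect for L and D. The helper name and the matrix
# values are assumptions made for this example, not part of the scnn API.
def _example_laplacian_tensor():
    L0 = sp.coo_matrix(np.array([[ 2.0, -1.0, -1.0],
                                 [-1.0,  2.0, -1.0],
                                 [-1.0, -1.0,  2.0]], dtype = np.float32))
    return coo2tensor(L0) # sparse (3, 3) torch tensor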

# Simplicial convolution: filters a (B, C_in, M) batch of cochains with an
# order-K (Chebyshev) polynomial in the Laplacian L and mixes channels.
class SimplicialConvolution(nn.Module):
    def __init__(self, K, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1):
        assert groups == 1, "Only groups = 1 is currently supported."
        super().__init__()

        assert(C_in > 0)
        assert(C_out > 0)
        assert(K > 0)
        
        self.C_in = C_in
        self.C_out = C_out
        self.K = K
        self.enable_bias = enable_bias
        self.variance = variance

        self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K)))
        if self.enable_bias:
            self.bias = nn.parameter.Parameter(torch.zeros((1, self.C_out, 1)))
        else:
            self.bias = 0.0
            
    def forward(self, L, x):
        # L: (M, M) Laplacian, typically a sparse tensor from coo2tensor.
        # x: batch of cochains of shape (B, C_in, M).
        assert(len(L.shape) == 2)
        assert(L.shape[0] == L.shape[1])

        (B, C_in, M) = x.shape

        assert(M == L.shape[0])
        assert(C_in == self.C_in)

        # X has shape (B, C_in, M, K): one slice per polynomial order k.
        # The einsum contracts over the input channels and the filter order.
        X = scnn.chebyshev.assemble(self.K, L, x)
        y = torch.einsum("bimk,oik->bom", X, self.theta)
        assert(y.shape == (B, self.C_out, M))

        return y + self.bias

    def __repr__(self):
        return "SimplicialConvolution(K=%d, C_in=%d, C_out=%d, enable_bias=%s, variance=%f)" %(self.K, self.C_in, self.C_out, self.enable_bias, self.variance)

# This class does not yet implement the
# Laplacian-power-pre/post-composed with the coboundary. It can be
# simulated by just adding more layers anyway, so keeping it simple
# for now.
#
# Note: You can use this for adjoints of coboundaries too. Just feed
# a transposed D.
# Coboundary layer: pushes a (B, C_in, M) batch of cochains through a fixed
# coboundary/incidence matrix D of shape (N, M) and mixes channels,
# producing output of shape (B, C_out, N).
class Coboundary(nn.Module):
    def __init__(self, C_in, C_out, enable_bias = True, variance = 1.0):
        super().__init__()

        assert(C_in > 0)
        assert(C_out > 0)

        self.C_in = C_in
        self.C_out = C_out
        self.enable_bias = enable_bias

        self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in)))
        if self.enable_bias:
            self.bias = nn.parameter.Parameter(torch.zeros((1, self.C_out, 1)))
        else:
            self.bias = 0.0

    def forward(self, D, x):
        assert(len(D.shape) == 2)
        
        (B, C_in, M) = x.shape
        
        assert(D.shape[1] == M)
        assert(C_in == self.C_in)
        
        N = D.shape[0]

        # Apply D to each channel of each sample in turn; this plays the same
        # role that chebyshev.assemble plays for the convolutional modules.
        X = []
        for b in range(0, B):
            X12 = []
            for c_in in range(0, self.C_in):
                # D.mm(x[b, c_in, :].unsqueeze(1)) has shape (N, 1); transpose it to (1, N).
                X12.append(D.mm(x[b, c_in, :].unsqueeze(1)).transpose(0,1))
            X12 = torch.cat(X12, 0)

            assert(X12.shape == (self.C_in, N))
            X.append(X12.unsqueeze(0))

        X = torch.cat(X, 0)
        assert(X.shape == (B, self.C_in, N))
                   
        y = torch.einsum("oi,bin->bon", (self.theta, X))
        assert(y.shape == (B, self.C_out, N))

        return y + self.bias
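
# Minimal usage sketch (illustrative only): apply a Coboundary layer with a
# random sparse incidence-like matrix D of shape (N, M). The sizes and the use
# of scipy.sparse.random are assumptions made for this example; as the comment
# above notes, feeding a transposed D gives the adjoint map instead.
def _example_coboundary():
    B, C_in, C_out, M, N = 2, 3, 4, 5, 7
    cob = Coboundary(C_in, C_out)
    D = coo2tensor(sp.random(N, M, density = 0.5, format = "coo", dtype = np.float32))
    x = torch.randn(B, C_in, M)
    return cob(D, x) # shape (B, C_out, N)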