From 712a580d2c40717d378ec36ced4ce04df541400a Mon Sep 17 00:00:00 2001 From: Gard Spreemann Date: Wed, 1 May 2019 17:54:15 +0200 Subject: Introduce bias ffs. --- scnn/scnn.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/scnn/scnn.py b/scnn/scnn.py index 12e220c..91f9706 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -12,7 +12,7 @@ def coo2tensor(A): return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False) class SimplicialConvolution(nn.Module): - def __init__(self, K, L, C_in, C_out, variance = 1.0, groups = 1): + def __init__(self, K, L, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1): assert groups == 1, "Only groups = 1 is currently supported." super().__init__() @@ -27,9 +27,14 @@ class SimplicialConvolution(nn.Module): self.C_out = C_out self.K = K self.L = L # Don't modify afterwards! + self.enable_bias = enable_bias self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K))) - + if self.enable_bias: + self.bias = nn.parameter.Parameter(torch.zeros((1, self.C_out, 1))) + else: + self.bias = 0.0 + def forward(self, x): (B, C_in, M) = x.shape @@ -40,7 +45,7 @@ class SimplicialConvolution(nn.Module): y = torch.einsum("bimk,oik->bom", (X, self.theta)) assert(y.shape == (B, self.C_out, M)) - return y + return y + self.bias # This class does not yet implement the # Laplacian-power-pre/post-composed with the coboundary. It can be @@ -50,7 +55,7 @@ class SimplicialConvolution(nn.Module): # Note: You can use this for a adjoints of coboundaries too. Just feed # a transposed D. class Coboundary(nn.Module): - def __init__(self, D, C_in, C_out, variance = 1.0): + def __init__(self, D, C_in, C_out, enable_bias = True, variance = 1.0): super().__init__() assert(len(D.shape) == 2) @@ -60,8 +65,13 @@ class Coboundary(nn.Module): self.C_in = C_in self.C_out = C_out self.D = D # Don't modify. + self.enable_bias = enable_bias self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in))) + if self.enable_bias: + self.bias = nn.parameter.Parameter(torch.zeros((1, self.C_out, 1))) + else: + self.bias = 0.0 def forward(self, x): (B, C_in, M) = x.shape @@ -87,4 +97,4 @@ class Coboundary(nn.Module): y = torch.einsum("oi,bin->bon", (self.theta, X)) assert(y.shape == (B, self.C_out, N)) - return y + return y + self.bias -- cgit v1.2.3