author     Gard Spreemann <gspr@nonempty.org>    2019-05-01 17:54:15 +0200
committer  Gard Spreemann <gspr@nonempty.org>    2019-05-01 17:54:15 +0200
commit     712a580d2c40717d378ec36ced4ce04df541400a (patch)
tree       43492beba982a0c51c973558d5023734eb7c18ec
parent     f7e7045f634cdac651da6968e9cf45749a966b01 (diff)
Introduce bias ffs.
-rw-r--r--  scnn/scnn.py  20
1 file changed, 15 insertions(+), 5 deletions(-)
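
In brief: both SimplicialConvolution and Coboundary gain an enable_bias flag
(default True). When enabled, each layer registers a learnable bias of shape
(1, C_out, 1) that is broadcast-added to its output; when disabled, self.bias
is the float 0.0 and the addition is a no-op. A minimal sketch of the
broadcasting this relies on (shapes here are illustrative, not from the
commit):

import torch

B, C_out, M = 4, 8, 100                   # batch, output channels, simplices
y = torch.randn(B, C_out, M)              # stand-in for a layer output
bias = torch.zeros(1, C_out, 1)           # one offset per output channel
assert (y + bias).shape == (B, C_out, M)  # broadcasts over both B and M
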
diff --git a/scnn/scnn.py b/scnn/scnn.py
index 12e220c..91f9706 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -12,7 +12,7 @@ def coo2tensor(A):
     return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False)
 
 class SimplicialConvolution(nn.Module):
-    def __init__(self, K, L, C_in, C_out, variance = 1.0, groups = 1):
+    def __init__(self, K, L, C_in, C_out, enable_bias = True, variance = 1.0, groups = 1):
         assert groups == 1, "Only groups = 1 is currently supported."
 
         super().__init__()
@@ -27,9 +27,14 @@ class SimplicialConvolution(nn.Module):
         self.C_out = C_out
         self.K = K
         self.L = L # Don't modify afterwards!
+        self.enable_bias = enable_bias
 
         self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K)))
-
+        if self.enable_bias:
+            self.bias = nn.parameter.Parameter(torch.zeros((1, self.C_out, 1)))
+        else:
+            self.bias = 0.0
+
 
     def forward(self, x):
         (B, C_in, M) = x.shape
@@ -40,7 +45,7 @@ class SimplicialConvolution(nn.Module):
         y = torch.einsum("bimk,oik->bom", (X, self.theta))
         assert(y.shape == (B, self.C_out, M))
 
-        return y
+        return y + self.bias
 
 # This class does not yet implement the
 # Laplacian-power-pre/post-composed with the coboundary. It can be
@@ -50,7 +55,7 @@ class SimplicialConvolution(nn.Module):
 # Note: You can use this for adjoints of coboundaries too. Just feed
 # a transposed D.
 class Coboundary(nn.Module):
-    def __init__(self, D, C_in, C_out, variance = 1.0):
+    def __init__(self, D, C_in, C_out, enable_bias = True, variance = 1.0):
         super().__init__()
 
         assert(len(D.shape) == 2)
@@ -60,8 +65,13 @@ class Coboundary(nn.Module):
         self.C_in = C_in
         self.C_out = C_out
         self.D = D # Don't modify.
+        self.enable_bias = enable_bias
 
         self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in)))
+        if self.enable_bias:
+            self.bias = nn.parameter.Parameter(torch.zeros((1, self.C_out, 1)))
+        else:
+            self.bias = 0.0
 
     def forward(self, x):
         (B, C_in, M) = x.shape
@@ -87,4 +97,4 @@ class Coboundary(nn.Module):
         y = torch.einsum("oi,bin->bon", (self.theta, X))
         assert(y.shape == (B, self.C_out, N))
 
-        return y
+        return y + self.bias
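
For context, a hypothetical sketch of the new flag in use. The shapes, the
SciPy-built operators, and the assumption that coo2tensor accepts a SciPy COO
matrix are illustrative guesses, not part of the commit:

import scipy.sparse
import torch
from scnn.scnn import SimplicialConvolution, Coboundary, coo2tensor

C_in, C_out, K, M, N = 3, 5, 4, 30, 40

# Placeholder operators: an identity Laplacian on M simplices and a random
# sparse N x M matrix standing in for a coboundary.
L = coo2tensor(scipy.sparse.identity(M, dtype = "float32", format = "coo"))
D = coo2tensor(scipy.sparse.random(N, M, density = 0.1, dtype = "float32", format = "coo"))

conv = SimplicialConvolution(K, L, C_in, C_out, enable_bias = True)
cob = Coboundary(D, C_in, C_out, enable_bias = False)

# theta is (C_out, C_in, K) plus a (1, C_out, 1) bias: 5*3*4 + 5 = 65.
assert sum(p.numel() for p in conv.parameters()) == C_out*C_in*K + C_out
# With enable_bias = False the bias is the plain float 0.0, not a Parameter,
# so only theta (C_out, C_in) is trainable: 5*3 = 15.
assert sum(p.numel() for p in cob.parameters()) == C_out*C_in

Per the asserts in the diff, conv maps (B, C_in, M) to (B, C_out, M) and cob
maps (B, C_in, M) to (B, C_out, N). Since a disabled bias is not an
nn.Parameter, it never appears in .parameters() and the optimizer never sees
it.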