summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2019-04-16 12:38:00 +0200
committerGard Spreemann <gspr@nonempty.org>2019-04-16 12:38:00 +0200
commit564e87132b1558a1f9dcda80981479f04cc7bc64 (patch)
treefd3d15a7d1cc377fcac5e14740eb17cd8b3fa897
parent90bbc1585e637a6f98e41177d8e88ce575724968 (diff)
Allow custom variance.
-rw-r--r--scnn/scnn.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py
index ed9ea2d..12e220c 100644
--- a/scnn/scnn.py
+++ b/scnn/scnn.py
@@ -12,7 +12,7 @@ def coo2tensor(A):
return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False)
class SimplicialConvolution(nn.Module):
- def __init__(self, K, L, C_in, C_out, groups = 1):
+ def __init__(self, K, L, C_in, C_out, variance = 1.0, groups = 1):
assert groups == 1, "Only groups = 1 is currently supported."
super().__init__()
@@ -28,7 +28,7 @@ class SimplicialConvolution(nn.Module):
self.K = K
self.L = L # Don't modify afterwards!
- self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K)))
+ self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K)))
def forward(self, x):
(B, C_in, M) = x.shape
@@ -50,7 +50,7 @@ class SimplicialConvolution(nn.Module):
# Note: You can use this for a adjoints of coboundaries too. Just feed
# a transposed D.
class Coboundary(nn.Module):
- def __init__(self, D, C_in, C_out):
+ def __init__(self, D, C_in, C_out, variance = 1.0):
super().__init__()
assert(len(D.shape) == 2)
@@ -61,7 +61,7 @@ class Coboundary(nn.Module):
self.C_out = C_out
self.D = D # Don't modify.
- self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in)))
+ self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in)))
def forward(self, x):
(B, C_in, M) = x.shape