diff options
Diffstat (limited to 'scnn')
-rw-r--r-- | scnn/scnn.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py index ed9ea2d..12e220c 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -12,7 +12,7 @@ def coo2tensor(A): return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False) class SimplicialConvolution(nn.Module): - def __init__(self, K, L, C_in, C_out, groups = 1): + def __init__(self, K, L, C_in, C_out, variance = 1.0, groups = 1): assert groups == 1, "Only groups = 1 is currently supported." super().__init__() @@ -28,7 +28,7 @@ class SimplicialConvolution(nn.Module): self.K = K self.L = L # Don't modify afterwards! - self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K))) + self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K))) def forward(self, x): (B, C_in, M) = x.shape @@ -50,7 +50,7 @@ class SimplicialConvolution(nn.Module): # Note: You can use this for a adjoints of coboundaries too. Just feed # a transposed D. class Coboundary(nn.Module): - def __init__(self, D, C_in, C_out): + def __init__(self, D, C_in, C_out, variance = 1.0): super().__init__() assert(len(D.shape) == 2) @@ -61,7 +61,7 @@ class Coboundary(nn.Module): self.C_out = C_out self.D = D # Don't modify. - self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in))) + self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in))) def forward(self, x): (B, C_in, M) = x.shape |