From 564e87132b1558a1f9dcda80981479f04cc7bc64 Mon Sep 17 00:00:00 2001 From: Gard Spreemann Date: Tue, 16 Apr 2019 12:38:00 +0200 Subject: Allow custom variance. --- scnn/scnn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scnn/scnn.py b/scnn/scnn.py index ed9ea2d..12e220c 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -12,7 +12,7 @@ def coo2tensor(A): return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False) class SimplicialConvolution(nn.Module): - def __init__(self, K, L, C_in, C_out, groups = 1): + def __init__(self, K, L, C_in, C_out, variance = 1.0, groups = 1): assert groups == 1, "Only groups = 1 is currently supported." super().__init__() @@ -28,7 +28,7 @@ class SimplicialConvolution(nn.Module): self.K = K self.L = L # Don't modify afterwards! - self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K))) + self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K))) def forward(self, x): (B, C_in, M) = x.shape @@ -50,7 +50,7 @@ class SimplicialConvolution(nn.Module): # Note: You can use this for a adjoints of coboundaries too. Just feed # a transposed D. class Coboundary(nn.Module): - def __init__(self, D, C_in, C_out): + def __init__(self, D, C_in, C_out, variance = 1.0): super().__init__() assert(len(D.shape) == 2) @@ -61,7 +61,7 @@ class Coboundary(nn.Module): self.C_out = C_out self.D = D # Don't modify. - self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in))) + self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in))) def forward(self, x): (B, C_in, M) = x.shape -- cgit v1.2.3