diff options
author | Gard Spreemann <gspr@nonempty.org> | 2019-04-16 12:38:00 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2019-04-16 12:38:00 +0200 |
commit | 564e87132b1558a1f9dcda80981479f04cc7bc64 (patch) | |
tree | fd3d15a7d1cc377fcac5e14740eb17cd8b3fa897 | |
parent | 90bbc1585e637a6f98e41177d8e88ce575724968 (diff) |
Allow custom variance.
-rw-r--r-- | scnn/scnn.py | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/scnn/scnn.py b/scnn/scnn.py index ed9ea2d..12e220c 100644 --- a/scnn/scnn.py +++ b/scnn/scnn.py @@ -12,7 +12,7 @@ def coo2tensor(A): return torch.sparse_coo_tensor(idxs, vals, size = A.shape, requires_grad = False) class SimplicialConvolution(nn.Module): - def __init__(self, K, L, C_in, C_out, groups = 1): + def __init__(self, K, L, C_in, C_out, variance = 1.0, groups = 1): assert groups == 1, "Only groups = 1 is currently supported." super().__init__() @@ -28,7 +28,7 @@ class SimplicialConvolution(nn.Module): self.K = K self.L = L # Don't modify afterwards! - self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in, self.K))) + self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in, self.K))) def forward(self, x): (B, C_in, M) = x.shape @@ -50,7 +50,7 @@ class SimplicialConvolution(nn.Module): # Note: You can use this for a adjoints of coboundaries too. Just feed # a transposed D. class Coboundary(nn.Module): - def __init__(self, D, C_in, C_out): + def __init__(self, D, C_in, C_out, variance = 1.0): super().__init__() assert(len(D.shape) == 2) @@ -61,7 +61,7 @@ class Coboundary(nn.Module): self.C_out = C_out self.D = D # Don't modify. - self.theta = nn.parameter.Parameter(torch.randn((self.C_out, self.C_in))) + self.theta = nn.parameter.Parameter(variance*torch.randn((self.C_out, self.C_in))) def forward(self, x): (B, C_in, M) = x.shape |