From fc737e745ec96019647dbd0ec6d1f5d1f0024213 Mon Sep 17 00:00:00 2001
From: peach-lucien
Date: Fri, 10 Nov 2023 18:06:54 +0100
Subject: [PATCH] updated to norm and skip layers

---
 MARBLE/layers.py | 29 +++++++++++++++++++++++++++++
 MARBLE/main.py   | 22 ++++++++++++++++++----
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/MARBLE/layers.py b/MARBLE/layers.py
index 06da2f2f..447e709c 100644
--- a/MARBLE/layers.py
+++ b/MARBLE/layers.py
@@ -6,6 +6,35 @@
 from MARBLE import geometry as g
 
 
+class SkipMLP(nn.Module):
+    """ MLP with skip connections """
+
+    def __init__(self, channel_list, dropout=0.0, bias=True):
+        super(SkipMLP, self).__init__()
+        assert len(channel_list) > 1, "Channel list must have at least two elements for an MLP."
+        self.layers = nn.ModuleList()
+        self.dropout = dropout
+        self.in_channels = channel_list[0]
+        for i in range(len(channel_list) - 1):
+            self.layers.append(nn.Linear(channel_list[i], channel_list[i+1], bias=bias))
+            if i < len(channel_list) - 2:  # Don't add activation or dropout to the last layer
+                self.layers.append(nn.ReLU(inplace=True))
+                if dropout > 0:
+                    self.layers.append(nn.Dropout(dropout))
+
+    def forward(self, x):
+        identity = x
+        for layer in self.layers:
+            if isinstance(layer, nn.Linear):
+                if x.shape[1] == layer.weight.shape[1]:  # Check if skip connection is possible
+                    identity = x  # Save identity for skip connection
+                x = layer(x)
+                if x.shape[1] == identity.shape[1]:  # Apply skip connection if shapes match
+                    x += identity
+            else:
+                x = layer(x)  # Apply activation or dropout
+        return x
+
 class Diffusion(nn.Module):
     """Diffusion with learned t."""
 
diff --git a/MARBLE/main.py b/MARBLE/main.py
index d612f673..b3545814 100644
--- a/MARBLE/main.py
+++ b/MARBLE/main.py
@@ -145,6 +145,7 @@ def check_parameters(self, data):
             "bias",
             "batch_norm",
             "vec_norm",
+            "emb_norm",
             "seed",
             "n_sampled_nb",
             "processes",
@@ -201,12 +202,22 @@ def setup_layers(self):
             + [self.params["out_channels"]]
         )
 
-        self.enc = MLP(
+        # self.enc = MLP(
+        #     channel_list=channel_list,
+        #     dropout=self.params["dropout"],
+        #     #norm=self.params["batch_norm"],
+        #     bias=self.params["bias"],
+        # )
+
+        self.enc = layers.SkipMLP(
             channel_list=channel_list,
             dropout=self.params["dropout"],
-            norm=self.params["batch_norm"],
+            #norm=self.params["batch_norm"],
             bias=self.params["bias"],
         )
+
+
+
 
     def forward(self, data, n_id, adjs=None):
         """Forward pass.
@@ -267,10 +278,13 @@ def forward(self, data, n_id, adjs=None):
         if self.params["include_positions"]:
             out = torch.hstack(
                 [data.pos[n_id[: size[1]]], out]  # pylint: disable=undefined-loop-variable
-            )
-
+            )
+
         emb = self.enc(out)
+        #if self.params['emb_norm']:
+        emb = F.normalize(emb)
+
 
         return emb, mask[: size[1]]
 
     def evaluate(self, data):
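
For reference, a minimal usage sketch of the new SkipMLP layer (the channel widths, dropout value, and batch size below are illustrative, not MARBLE defaults). A skip connection is only added around a Linear layer whose input and output widths match, so the final projection is applied without one:

import torch
from MARBLE.layers import SkipMLP  # available once this patch is applied

mlp = SkipMLP(channel_list=[64, 64, 32], dropout=0.1, bias=True)
x = torch.randn(8, 64)   # batch of 8 feature vectors
out = mlp(x)             # the 64 -> 64 block adds x back (shapes match);
                         # the 64 -> 32 projection is applied without a skip
print(out.shape)         # torch.Size([8, 32])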
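
The other behavioural change is the embedding normalisation in forward(): with its default arguments, F.normalize rescales each embedding row to unit L2 norm (the emb_norm gate is still commented out in this patch, so the normalisation runs unconditionally). A self-contained illustration with arbitrary values:

import torch
import torch.nn.functional as F

emb = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
emb = F.normalize(emb)       # defaults: p=2, dim=1 (row-wise)
print(emb)                   # [[0.6, 0.8], [0.0, 1.0]]
print(emb.norm(dim=1))       # both rows now have unit norm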