Commit fc737e7
updated to norm and skip layers
peach-lucien committed Nov 10, 2023
1 parent 38cb48a commit fc737e7
Showing 2 changed files with 47 additions and 4 deletions.
29 changes: 29 additions & 0 deletions MARBLE/layers.py
@@ -6,6 +6,35 @@

from MARBLE import geometry as g

class SkipMLP(nn.Module):
    """ MLP with skip connections """

    def __init__(self, channel_list, dropout=0.0, bias=True):
        super(SkipMLP, self).__init__()
        assert len(channel_list) > 1, "Channel list must have at least two elements for an MLP."
        self.layers = nn.ModuleList()
        self.dropout = dropout
        self.in_channels = channel_list[0]
        for i in range(len(channel_list) - 1):
            self.layers.append(nn.Linear(channel_list[i], channel_list[i+1], bias=bias))
            if i < len(channel_list) - 2:  # Don't add activation or dropout to the last layer
                self.layers.append(nn.ReLU(inplace=True))
                if dropout > 0:
                    self.layers.append(nn.Dropout(dropout))

    def forward(self, x):
        identity = x
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                if x.shape[1] == layer.weight.shape[1]:  # Check if skip connection is possible
                    identity = x  # Save identity for skip connection
                x = layer(x)
                if x.shape[1] == identity.shape[1]:  # Apply skip connection if shapes match
                    x += identity
            else:
                x = layer(x)  # Apply activation or dropout
        return x
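
For reference, a minimal usage sketch (not part of this commit) showing how the new SkipMLP could be exercised on dummy data, assuming the package layout above (MARBLE/layers.py) and a standard PyTorch install:

import torch

from MARBLE.layers import SkipMLP

# channel_list = [in, hidden, hidden, out]; the middle Linear maps 16 -> 16,
# so the identity saved before it can be added back after it (the skip connection).
mlp = SkipMLP(channel_list=[32, 16, 16, 8], dropout=0.1, bias=True)

x = torch.randn(64, 32)   # batch of 64 samples with 32 features
out = mlp(x)              # Linear(32,16) -> ReLU/Dropout -> Linear(16,16)+skip -> ReLU/Dropout -> Linear(16,8)
print(out.shape)          # torch.Size([64, 8])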


class Diffusion(nn.Module):
    """Diffusion with learned t."""
22 changes: 18 additions & 4 deletions MARBLE/main.py
@@ -145,6 +145,7 @@ def check_parameters(self, data):
"bias",
"batch_norm",
"vec_norm",
"emb_norm",
"seed",
"n_sampled_nb",
"processes",
@@ -201,12 +202,22 @@ def setup_layers(self):
+ [self.params["out_channels"]]
)

        self.enc = MLP(
        # self.enc = MLP(
        #     channel_list=channel_list,
        #     dropout=self.params["dropout"],
        #     #norm=self.params["batch_norm"],
        #     bias=self.params["bias"],
        # )

        self.enc = layers.SkipMLP(
            channel_list=channel_list,
            dropout=self.params["dropout"],
            norm=self.params["batch_norm"],
            #norm=self.params["batch_norm"],
            bias=self.params["bias"],
        )




    def forward(self, data, n_id, adjs=None):
        """Forward pass.
@@ -267,10 +278,13 @@ def forward(self, data, n_id, adjs=None):
        if self.params["include_positions"]:
            out = torch.hstack(
                [data.pos[n_id[: size[1]]], out]  # pylint: disable=undefined-loop-variable
            )
            )

        emb = self.enc(out)

        #if self.params['emb_norm']:
        emb = F.normalize(emb)

        return emb, mask[: size[1]]
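
For context (not part of this commit): with default arguments, F.normalize rescales each row of the embedding matrix to unit L2 norm, and since the emb_norm check above is commented out, this rescaling is applied unconditionally. A small sketch of the behaviour, assuming the torch.nn.functional-as-F import that the added line presupposes:

import torch
import torch.nn.functional as F

emb = torch.randn(5, 3)        # five 3-dimensional embeddings
emb_unit = F.normalize(emb)    # defaults p=2, dim=1: each row scaled to unit L2 norm
print(emb_unit.norm(dim=1))    # approximately tensor([1., 1., 1., 1., 1.])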

def evaluate(self, data):
