models.py
import torch
from torch import nn
import torch.nn.functional as F
from layers import GraphAttentionLayer
################################
### GAT NETWORK MODULE ###
################################
class GAT(nn.Module):
    """
    Graph Attention Network (GAT), as described in the paper `"Graph Attention Networks" <https://arxiv.org/pdf/1710.10903.pdf>`_.

    Consists of a stack of two Graph Attention Layers. The first layer is multi-headed and
    followed by an ELU activation; the second (final) layer uses a single attention head
    and is followed by a log-softmax over the class dimension.
    """
    def __init__(self,
                 in_features,
                 n_hidden,
                 n_heads,
                 num_classes,
                 concat=False,
                 dropout=0.4,
                 leaky_relu_slope=0.2):
        """ Initializes the GAT model.

        Args:
            in_features (int): number of input features per node.
            n_hidden (int): output size of the first Graph Attention Layer.
            n_heads (int): number of attention heads in the first Graph Attention Layer.
            num_classes (int): number of classes to predict for each node.
            concat (bool, optional): Whether to concatenate the attention heads or average them
                in the output of the first Graph Attention Layer. Defaults to False.
            dropout (float, optional): dropout rate. Defaults to 0.4.
            leaky_relu_slope (float, optional): alpha (negative slope) of the LeakyReLU activation. Defaults to 0.2.
        """
        super(GAT, self).__init__()

        # Define the two Graph Attention layers
        self.gat1 = GraphAttentionLayer(
            in_features=in_features, out_features=n_hidden, n_heads=n_heads,
            concat=concat, dropout=dropout, leaky_relu_slope=leaky_relu_slope
        )
        self.gat2 = GraphAttentionLayer(
            in_features=n_hidden, out_features=num_classes, n_heads=1,
            concat=False, dropout=dropout, leaky_relu_slope=leaky_relu_slope
        )
    def forward(self, input_tensor: torch.Tensor, adj_mat: torch.Tensor):
        """
        Performs a forward pass through the network.

        Args:
            input_tensor (torch.Tensor): Input tensor of node features.
            adj_mat (torch.Tensor): Adjacency matrix representing the graph structure.

        Returns:
            torch.Tensor: Log-softmax class scores for each node.
        """
        # Apply the first Graph Attention layer
        x = self.gat1(input_tensor, adj_mat)
        x = F.elu(x)  # ELU activation on the output of the first layer

        # Apply the second Graph Attention layer
        x = self.gat2(x, adj_mat)

        return F.log_softmax(x, dim=1)  # Log-softmax over the class dimension
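

################################
###     USAGE SKETCH         ###
################################
# A minimal usage sketch, assuming GraphAttentionLayer (from layers.py) accepts a dense
# (n_nodes, n_nodes) adjacency tensor with self-loops; the shapes and hyperparameter
# values below are illustrative only, not taken from the repository.
if __name__ == "__main__":
    n_nodes, in_features, n_hidden, n_heads, num_classes = 6, 8, 16, 4, 3

    # Random node features and a symmetric {0, 1} adjacency matrix with self-loops.
    features = torch.rand(n_nodes, in_features)
    adj_mat = torch.randint(0, 2, (n_nodes, n_nodes))
    adj_mat = ((adj_mat + adj_mat.T) > 0).float()
    adj_mat.fill_diagonal_(1)

    model = GAT(in_features=in_features, n_hidden=n_hidden,
                n_heads=n_heads, num_classes=num_classes)
    log_probs = model(features, adj_mat)  # per-node log-probabilities, shape (n_nodes, num_classes)
    print(log_probs.shape)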