import torch
from torch import nn
import torch.nn.functional as F
################################
### GAT LAYER ###
################################
class GraphAttentionLayer(nn.Module):
"""
Graph Attention Layer (GAT) as described in the paper `"Graph Attention Networks" <https://arxiv.org/pdf/1710.10903.pdf>`.
This operation can be mathematically described as:
e_ij = a(W h_i, W h_j)
α_ij = softmax_j(e_ij) = exp(e_ij) / Σ_k(exp(e_ik))
h_i' = σ(Σ_j(α_ij W h_j))
where h_i and h_j are the feature vectors of nodes i and j respectively, W is a learnable weight matrix,
a is an attention mechanism that computes the attention coefficients e_ij, and σ is an activation function.
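    Example (a minimal, illustrative shape check on a toy fully connected graph; the sizes below are
    arbitrary assumptions, not values from the paper):
        >>> layer = GraphAttentionLayer(in_features=8, out_features=16, n_heads=4, concat=True)
        >>> h = torch.rand(5, 8)        # 5 nodes with 8 input features each
        >>> adj_mat = torch.ones(5, 5)  # fully connected toy graph
        >>> layer(h, adj_mat).shape
        torch.Size([5, 16])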
"""
def __init__(self, in_features: int, out_features: int, n_heads: int, concat: bool = False, dropout: float = 0.4, leaky_relu_slope: float = 0.2):
super(GraphAttentionLayer, self).__init__()
self.n_heads = n_heads # Number of attention heads
        self.concat = concat # whether to concatenate the final attention heads
self.dropout = dropout # Dropout rate
if concat: # concatenating the attention heads
self.out_features = out_features # Number of output features per node
assert out_features % n_heads == 0 # Ensure that out_features is a multiple of n_heads
self.n_hidden = out_features // n_heads
else: # averaging output over the attention heads (Used in the main paper)
self.n_hidden = out_features
# A shared linear transformation, parametrized by a weight matrix W is applied to every node
# Initialize the weight matrix W
self.W = nn.Parameter(torch.empty(size=(in_features, self.n_hidden * n_heads)))
# Initialize the attention weights a
self.a = nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1)))
self.leakyrelu = nn.LeakyReLU(leaky_relu_slope) # LeakyReLU activation function
        self.softmax = nn.Softmax(dim=1) # softmax activation for the attention coefficients (forward uses F.softmax directly)
self.reset_parameters() # Reset the parameters
def reset_parameters(self):
"""
Reinitialize learnable parameters.
"""
nn.init.xavier_normal_(self.W)
nn.init.xavier_normal_(self.a)
def _get_attention_scores(self, h_transformed: torch.Tensor):
"""calculates the attention scores e_ij for all pairs of nodes (i, j) in the graph
in vectorized parallel form. for each pair of source and target nodes (i, j),
the attention score e_ij is computed as follows:
e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
where || denotes the concatenation operation, and a and W are the learnable parameters.
Args:
            h_transformed (torch.Tensor): Transformed node feature matrix of shape (n_heads, n_nodes, n_hidden),
                where n_heads is the number of attention heads, n_nodes is the number of nodes, and n_hidden is the number of hidden features per head.
Returns:
torch.Tensor: Attention score matrix with shape (n_heads, n_nodes, n_nodes), where n_nodes is the number of nodes.
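        Example (an illustrative shape check; the layer sizes used here are arbitrary assumptions):
            >>> layer = GraphAttentionLayer(in_features=6, out_features=6, n_heads=2, concat=True)  # n_hidden = 3
            >>> h_t = torch.rand(2, 4, 3)  # (n_heads, n_nodes, n_hidden) for a 4-node graph
            >>> layer._get_attention_scores(h_t).shape
            torch.Size([2, 4, 4])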
"""
source_scores = torch.matmul(h_transformed, self.a[:, :self.n_hidden, :])
target_scores = torch.matmul(h_transformed, self.a[:, self.n_hidden:, :])
# broadcast add
# (n_heads, n_nodes, 1) + (n_heads, 1, n_nodes) = (n_heads, n_nodes, n_nodes)
e = source_scores + target_scores.mT
return self.leakyrelu(e)
def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
"""
Performs a graph attention layer operation.
Args:
            h (torch.Tensor): Input node feature matrix of shape (n_nodes, in_features).
            adj_mat (torch.Tensor): Adjacency matrix of shape (n_nodes, n_nodes); nonzero entries indicate edges.
        Returns:
            torch.Tensor: Output node embeddings of shape (n_nodes, out_features) after the graph attention operation.
"""
n_nodes = h.shape[0]
        # Apply the shared linear transformation to the node features -> W h
# output shape (n_nodes, n_hidden * n_heads)
h_transformed = torch.mm(h, self.W)
h_transformed = F.dropout(h_transformed, self.dropout, training=self.training)
# splitting the heads by reshaping the tensor and putting heads dim first
# output shape (n_heads, n_nodes, n_hidden)
h_transformed = h_transformed.view(n_nodes, self.n_heads, self.n_hidden).permute(1, 0, 2)
# getting the attention scores
# output shape (n_heads, n_nodes, n_nodes)
e = self._get_attention_scores(h_transformed)
        # Mask non-existent edges by setting their attention scores to a large negative value (-9e16)
connectivity_mask = -9e16 * torch.ones_like(e)
e = torch.where(adj_mat > 0, e, connectivity_mask) # masked attention scores
        # attention coefficients are computed as a softmax over each row of e,
        # i.e. normalized over the neighbors j of every node i
attention = F.softmax(e, dim=-1)
attention = F.dropout(attention, self.dropout, training=self.training)
        # the final embedding of each node is a weighted average of its neighbors' transformed features
h_prime = torch.matmul(attention, h_transformed)
# concatenating/averaging the attention heads
# output shape (n_nodes, out_features)
if self.concat:
h_prime = h_prime.permute(1, 0, 2).contiguous().view(n_nodes, self.out_features)
else:
h_prime = h_prime.mean(dim=0)
return h_prime
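

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): builds a small random graph and
# runs the layer in both averaging (concat=False) and concatenating
# (concat=True) modes to check the output shapes. The sizes and the random
# adjacency matrix below are arbitrary examples, not values from the paper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_nodes, in_features, out_features, n_heads = 6, 10, 8, 4

    h = torch.rand(n_nodes, in_features)                    # random node features
    adj_mat = (torch.rand(n_nodes, n_nodes) > 0.5).float()  # random edges
    adj_mat.fill_diagonal_(1.0)                             # self-loops so every node attends to itself

    # averaging the attention heads (concat=False, the default)
    gat_avg = GraphAttentionLayer(in_features, out_features, n_heads, concat=False)
    print(gat_avg(h, adj_mat).shape)  # torch.Size([6, 8])

    # concatenating the attention heads (out_features must be divisible by n_heads)
    gat_cat = GraphAttentionLayer(in_features, out_features, n_heads, concat=True)
    print(gat_cat(h, adj_mat).shape)  # torch.Size([6, 8])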