
Commit

add back local agg
ines-chami committed Oct 3, 2020
1 parent dc48b94 commit a526385
Showing 3 changed files with 19 additions and 6 deletions.
1 change: 1 addition & 0 deletions config.py
@@ -41,6 +41,7 @@
         'alpha': (0.2, 'alpha for leakyrelu in graph attention networks'),
         'double-precision': ('0', 'whether to use double precision'),
         'use-att': (0, 'whether to use hyperbolic attention or not'),
+        'local-agg': (0, 'whether to use local tangent space aggregation or not')
     },
     'data_config': {
         'dataset': ('cora', 'which dataset to use'),
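The new 'local-agg' entry follows the same pattern as the other config entries, so it presumably surfaces as a command-line flag of the training script. A minimal sketch of that mapping, assuming an argparse-based parser; the parser wiring below is an assumption for illustration, not code from this commit:

    # Sketch only: how a config entry such as 'local-agg' typically becomes a CLI flag.
    # The parser construction here is assumed, not taken from the repository.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--local-agg', type=int, default=0,
                        help='whether to use local tangent space aggregation or not')
    args = parser.parse_args(['--local-agg', '1'])
    print(args.local_agg)  # 1; argparse maps the hyphen to args.local_agg, as read in models/encoders.py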
22 changes: 17 additions & 5 deletions layers/hyp_layers.py
@@ -60,10 +60,10 @@ class HyperbolicGraphConvolution(nn.Module):
     Hyperbolic graph convolution layer.
     """
 
-    def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att):
+    def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att, local_agg):
         super(HyperbolicGraphConvolution, self).__init__()
         self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias)
-        self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att)
+        self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att, local_agg)
         self.hyp_act = HypAct(manifold, c_in, c_out, act)
 
     def forward(self, input):
@@ -119,22 +119,34 @@ class HypAgg(Module):
     Hyperbolic aggregation layer.
     """
 
-    def __init__(self, manifold, c, in_features, dropout, use_att):
+    def __init__(self, manifold, c, in_features, dropout, use_att, local_agg):
         super(HypAgg, self).__init__()
         self.manifold = manifold
         self.c = c
 
         self.in_features = in_features
         self.dropout = dropout
+        self.local_agg = local_agg
         self.use_att = use_att
         if self.use_att:
             self.att = DenseAtt(in_features, dropout)
 
     def forward(self, x, adj):
         x_tangent = self.manifold.logmap0(x, c=self.c)
         if self.use_att:
-            adj_att = self.att(x_tangent, adj)
-            support_t = torch.matmul(adj_att, x_tangent)
+            if self.local_agg:
+                x_local_tangent = []
+                for i in range(x.size(0)):
+                    x_local_tangent.append(self.manifold.logmap(x[i], x, c=self.c))
+                x_local_tangent = torch.stack(x_local_tangent, dim=0)
+                adj_att = self.att(x_tangent, adj)
+                att_rep = adj_att.unsqueeze(-1) * x_local_tangent
+                support_t = torch.sum(att_rep, dim=1)
+                output = self.manifold.proj(self.manifold.expmap(x, support_t, c=self.c), c=self.c)
+                return output
+            else:
+                adj_att = self.att(x_tangent, adj)
+                support_t = torch.matmul(adj_att, x_tangent)
         else:
             support_t = torch.spmm(adj, x_tangent)
         output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
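With local_agg enabled (and use_att on), neighbors are no longer aggregated in the tangent space at the origin: each node's neighbors are first mapped into that node's own tangent space with logmap, attention-weighted and summed there, and the result is mapped back with expmap at the node. The following self-contained sketch mirrors that control flow using a flat Euclidean stand-in for the manifold, purely to illustrate shapes and structure; the stand-in class and its argument conventions are assumptions, not the repository's hyperbolic Manifold API.

    # Sketch of the aggregation structure above with a Euclidean stand-in manifold.
    # log/exp maps reduce to vector differences/sums here; this is not hyperbolic math.
    import torch

    class FlatManifold:
        def logmap0(self, x, c): return x                 # tangent space at the origin
        def expmap0(self, u, c): return u
        def proj(self, x, c): return x
        def logmap(self, p, x, c): return x - p           # tangent vectors of x at base point p
        def expmap(self, p, u, c): return p + u           # move from p along tangent vector u

    manifold, c = FlatManifold(), 1.0
    n, d = 4, 8
    x = torch.randn(n, d)                                 # node embeddings
    adj_att = torch.softmax(torch.randn(n, n), dim=-1)    # stand-in for DenseAtt's output

    # Global aggregation (local_agg off): everything happens in the origin's tangent space.
    out_global = manifold.proj(manifold.expmap0(adj_att @ manifold.logmap0(x, c), c), c)

    # Local aggregation (local_agg on): neighbors expressed in each node's own tangent space.
    x_local_tangent = torch.stack([manifold.logmap(x[i], x, c) for i in range(n)], dim=0)  # (n, n, d)
    support_t = (adj_att.unsqueeze(-1) * x_local_tangent).sum(dim=1)                       # (n, d)
    out_local = manifold.proj(manifold.expmap(x, support_t, c), c)
    print(out_global.shape, out_local.shape)              # torch.Size([4, 8]) torch.Size([4, 8])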
2 changes: 1 addition & 1 deletion models/encoders.py
@@ -108,7 +108,7 @@ def __init__(self, c, args):
             act = acts[i]
             hgc_layers.append(
                     hyp_layers.HyperbolicGraphConvolution(
-                            self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att
+                            self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att, args.local_agg
                     )
             )
         self.layers = nn.Sequential(*hgc_layers)
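Since the encoder simply threads args.local_agg through to each HyperbolicGraphConvolution, a quick signature check is enough to confirm the new keyword is accepted end to end. A hedged sketch, assuming the repository root is on the Python path (the import path follows the file names shown above):

    # Sanity-check sketch: both changed layer classes now accept local_agg.
    # Assumes this runs from the repository root so layers/hyp_layers.py is importable.
    import inspect
    from layers.hyp_layers import HyperbolicGraphConvolution, HypAgg

    for cls in (HyperbolicGraphConvolution, HypAgg):
        params = inspect.signature(cls.__init__).parameters
        assert 'local_agg' in params, f'{cls.__name__} is missing local_agg'
    print('local_agg is accepted by both layers')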
