diff --git a/config.py b/config.py index d45774d..08e99cc 100644 --- a/config.py +++ b/config.py @@ -41,6 +41,7 @@ 'alpha': (0.2, 'alpha for leakyrelu in graph attention networks'), 'double-precision': ('0', 'whether to use double precision'), 'use-att': (0, 'whether to use hyperbolic attention or not'), + 'local-agg': (0, 'whether to local tangent space aggregation or not') }, 'data_config': { 'dataset': ('cora', 'which dataset to use'), diff --git a/layers/hyp_layers.py b/layers/hyp_layers.py index 45e4755..b07afc4 100644 --- a/layers/hyp_layers.py +++ b/layers/hyp_layers.py @@ -60,10 +60,10 @@ class HyperbolicGraphConvolution(nn.Module): Hyperbolic graph convolution layer. """ - def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att): + def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att, local_agg): super(HyperbolicGraphConvolution, self).__init__() self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias) - self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att) + self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att, local_agg) self.hyp_act = HypAct(manifold, c_in, c_out, act) def forward(self, input): @@ -119,13 +119,14 @@ class HypAgg(Module): Hyperbolic aggregation layer. """ - def __init__(self, manifold, c, in_features, dropout, use_att): + def __init__(self, manifold, c, in_features, dropout, use_att, local_agg): super(HypAgg, self).__init__() self.manifold = manifold self.c = c self.in_features = in_features self.dropout = dropout + self.local_agg = local_agg self.use_att = use_att if self.use_att: self.att = DenseAtt(in_features, dropout) @@ -133,8 +134,19 @@ def __init__(self, manifold, c, in_features, dropout, use_att): def forward(self, x, adj): x_tangent = self.manifold.logmap0(x, c=self.c) if self.use_att: - adj_att = self.att(x_tangent, adj) - support_t = torch.matmul(adj_att, x_tangent) + if self.local_agg: + x_local_tangent = [] + for i in range(x.size(0)): + x_local_tangent.append(self.manifold.logmap(x[i], x, c=self.c)) + x_local_tangent = torch.stack(x_local_tangent, dim=0) + adj_att = self.att(x_tangent, adj) + att_rep = adj_att.unsqueeze(-1) * x_local_tangent + support_t = torch.sum(adj_att.unsqueeze(-1) * x_local_tangent, dim=1) + output = self.manifold.proj(self.manifold.expmap(x, support_t, c=self.c), c=self.c) + return output + else: + adj_att = self.att(x_tangent, adj) + support_t = torch.matmul(adj_att, x_tangent) else: support_t = torch.spmm(adj, x_tangent) output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c) diff --git a/models/encoders.py b/models/encoders.py index 6505077..71874b8 100644 --- a/models/encoders.py +++ b/models/encoders.py @@ -108,7 +108,7 @@ def __init__(self, c, args): act = acts[i] hgc_layers.append( hyp_layers.HyperbolicGraphConvolution( - self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att + self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att, args.local_agg ) ) self.layers = nn.Sequential(*hgc_layers)