diff --git a/config.py b/config.py index a22d1a8..d45774d 100644 --- a/config.py +++ b/config.py @@ -39,7 +39,8 @@ 'act': ('relu', 'which activation function to use (or None for no activation)'), 'n-heads': (4, 'number of attention heads for graph attention networks, must be a divisor dim'), 'alpha': (0.2, 'alpha for leakyrelu in graph attention networks'), - 'double-precision': ('0', 'whether to use double precision') + 'double-precision': ('0', 'whether to use double precision'), + 'use-att': (0, 'whether to use hyperbolic attention or not'), }, 'data_config': { 'dataset': ('cora', 'which dataset to use'), diff --git a/layers/att_layers.py b/layers/att_layers.py index 21ffb96..8414d8d 100644 --- a/layers/att_layers.py +++ b/layers/att_layers.py @@ -6,11 +6,10 @@ class DenseAtt(nn.Module): - def __init__(self, in_features, dropout, act): + def __init__(self, in_features, dropout): super(DenseAtt, self).__init__() self.dropout = dropout self.linear = nn.Linear(2 * in_features, 1, bias=True) - self.act = act self.in_features = in_features def forward (self, x, adj): diff --git a/layers/hyp_layers.py b/layers/hyp_layers.py index 721d2a2..45e4755 100644 --- a/layers/hyp_layers.py +++ b/layers/hyp_layers.py @@ -6,7 +6,8 @@ import torch.nn.functional as F import torch.nn.init as init from torch.nn.modules.module import Module -from torch.nn.parameter import Parameter + +from layers.att_layers import DenseAtt def get_dim_act_curv(args): @@ -38,7 +39,6 @@ def get_dim_act_curv(args): return dims, acts, curvatures - class HNNLayer(nn.Module): """ Hyperbolic neural networks layer. @@ -60,10 +60,10 @@ class HyperbolicGraphConvolution(nn.Module): Hyperbolic graph convolution layer. """ - def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias): + def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att): super(HyperbolicGraphConvolution, self).__init__() self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias) - self.agg = HypAgg(manifold, c_in, out_features, dropout) + self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att) self.hyp_act = HypAct(manifold, c_in, c_out, act) def forward(self, input): @@ -100,18 +100,17 @@ def forward(self, x): drop_weight = F.dropout(self.weight, self.dropout, training=self.training) mv = self.manifold.mobius_matvec(drop_weight, x, self.c) res = self.manifold.proj(mv, self.c) - if self.use_bias: + if self.use_bias: bias = self.manifold.proj_tan0(self.bias.view(1, -1), self.c) hyp_bias = self.manifold.expmap0(bias, self.c) hyp_bias = self.manifold.proj(hyp_bias, self.c) res = self.manifold.mobius_add(res, hyp_bias, c=self.c) res = self.manifold.proj(res, self.c) return res - def extra_repr(self): return 'in_features={}, out_features={}, c={}'.format( - self.in_features, self.out_features, self.c + self.in_features, self.out_features, self.c ) @@ -120,17 +119,24 @@ class HypAgg(Module): Hyperbolic aggregation layer. """ - def __init__(self, manifold, c, in_features, dropout): + def __init__(self, manifold, c, in_features, dropout, use_att): super(HypAgg, self).__init__() self.manifold = manifold self.c = c self.in_features = in_features self.dropout = dropout + self.use_att = use_att + if self.use_att: + self.att = DenseAtt(in_features, dropout) def forward(self, x, adj): x_tangent = self.manifold.logmap0(x, c=self.c) - support_t = torch.spmm(adj, x_tangent) + if self.use_att: + adj_att = self.att(x_tangent, adj) + support_t = torch.matmul(adj_att, x_tangent) + else: + support_t = torch.spmm(adj, x_tangent) output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c) return output @@ -157,5 +163,5 @@ def forward(self, x): def extra_repr(self): return 'c_in={}, c_out={}'.format( - self.c_in, self.c_out + self.c_in, self.c_out ) diff --git a/models/encoders.py b/models/encoders.py index 24c951c..6505077 100644 --- a/models/encoders.py +++ b/models/encoders.py @@ -108,7 +108,7 @@ def __init__(self, c, args): act = acts[i] hgc_layers.append( hyp_layers.HyperbolicGraphConvolution( - self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias + self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att ) ) self.layers = nn.Sequential(*hgc_layers)