Skip to content

Commit

Permalink
added att_0 back to aggregation layer
Browse files Browse the repository at this point in the history
  • Loading branch information
ines-chami committed Jul 22, 2020
1 parent ff621d6 commit dc48b94
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
3 changes: 2 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
'act': ('relu', 'which activation function to use (or None for no activation)'),
'n-heads': (4, 'number of attention heads for graph attention networks, must be a divisor dim'),
'alpha': (0.2, 'alpha for leakyrelu in graph attention networks'),
'double-precision': ('0', 'whether to use double precision')
'double-precision': ('0', 'whether to use double precision'),
'use-att': (0, 'whether to use hyperbolic attention or not'),
},
'data_config': {
'dataset': ('cora', 'which dataset to use'),
Expand Down
3 changes: 1 addition & 2 deletions layers/att_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@


class DenseAtt(nn.Module):
def __init__(self, in_features, dropout, act):
def __init__(self, in_features, dropout):
super(DenseAtt, self).__init__()
self.dropout = dropout
self.linear = nn.Linear(2 * in_features, 1, bias=True)
self.act = act
self.in_features = in_features

def forward (self, x, adj):
Expand Down
26 changes: 16 additions & 10 deletions layers/hyp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

from layers.att_layers import DenseAtt


def get_dim_act_curv(args):
Expand Down Expand Up @@ -38,7 +39,6 @@ def get_dim_act_curv(args):
return dims, acts, curvatures



class HNNLayer(nn.Module):
"""
Hyperbolic neural networks layer.
Expand All @@ -60,10 +60,10 @@ class HyperbolicGraphConvolution(nn.Module):
Hyperbolic graph convolution layer.
"""

def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias):
def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att):
super(HyperbolicGraphConvolution, self).__init__()
self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias)
self.agg = HypAgg(manifold, c_in, out_features, dropout)
self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att)
self.hyp_act = HypAct(manifold, c_in, c_out, act)

def forward(self, input):
Expand Down Expand Up @@ -100,18 +100,17 @@ def forward(self, x):
drop_weight = F.dropout(self.weight, self.dropout, training=self.training)
mv = self.manifold.mobius_matvec(drop_weight, x, self.c)
res = self.manifold.proj(mv, self.c)
if self.use_bias:
if self.use_bias:
bias = self.manifold.proj_tan0(self.bias.view(1, -1), self.c)
hyp_bias = self.manifold.expmap0(bias, self.c)
hyp_bias = self.manifold.proj(hyp_bias, self.c)
res = self.manifold.mobius_add(res, hyp_bias, c=self.c)
res = self.manifold.proj(res, self.c)
return res


def extra_repr(self):
return 'in_features={}, out_features={}, c={}'.format(
self.in_features, self.out_features, self.c
self.in_features, self.out_features, self.c
)


Expand All @@ -120,17 +119,24 @@ class HypAgg(Module):
Hyperbolic aggregation layer.
"""

def __init__(self, manifold, c, in_features, dropout):
def __init__(self, manifold, c, in_features, dropout, use_att):
super(HypAgg, self).__init__()
self.manifold = manifold
self.c = c

self.in_features = in_features
self.dropout = dropout
self.use_att = use_att
if self.use_att:
self.att = DenseAtt(in_features, dropout)

def forward(self, x, adj):
x_tangent = self.manifold.logmap0(x, c=self.c)
support_t = torch.spmm(adj, x_tangent)
if self.use_att:
adj_att = self.att(x_tangent, adj)
support_t = torch.matmul(adj_att, x_tangent)
else:
support_t = torch.spmm(adj, x_tangent)
output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
return output

Expand All @@ -157,5 +163,5 @@ def forward(self, x):

def extra_repr(self):
return 'c_in={}, c_out={}'.format(
self.c_in, self.c_out
self.c_in, self.c_out
)
2 changes: 1 addition & 1 deletion models/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, c, args):
act = acts[i]
hgc_layers.append(
hyp_layers.HyperbolicGraphConvolution(
self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias
self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att
)
)
self.layers = nn.Sequential(*hgc_layers)
Expand Down

0 comments on commit dc48b94

Please sign in to comment.