-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathStarUNetgpu.py
127 lines (104 loc) · 4.76 KB
/
StarUNetgpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn.functional as F
from torch_sparse import spspmm
from torch_geometric.nn import GCNConv
from torch_geometric.utils import (add_self_loops, sort_edge_index,
remove_self_loops)
from star_pool_gpu import StarPooling
class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    In this variant every pooling step is performed by a
    :class:`StarPooling` layer (imported from ``star_pool_gpu``), and all
    convolutions are :class:`GCNConv` layers with ``improved=True``.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture, i.e. the number of
            pooling (and matching unpooling) levels. Must be at least 1.
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation instead of summation for integration of skip
            connections. (default: :obj:`True`)
        act (callable, optional): The nonlinearity to use.
            (default: :obj:`torch.nn.functional.relu`)
    """

    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 sum_res=True, act=F.relu):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.act = act
        self.sum_res = sum_res

        channels = hidden_channels

        # Encoder: one input convolution, then `depth` (pool -> conv) stages.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for _ in range(depth):
            self.pools.append(StarPooling(channels))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Decoder: skip connections are either summed (width `channels`) or
        # concatenated (width `2 * channels`) before each up-convolution.
        up_in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for _ in range(depth - 1):
            self.up_convs.append(
                GCNConv(up_in_channels, channels, improved=True))
        self.up_convs.append(
            GCNConv(up_in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every convolution and pooling sub-module."""
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, batch=None):
        """Run the encoder (pooling) and decoder (unpooling) passes.

        Args:
            x (Tensor): Node feature matrix; ``x.size(0)`` is the node count.
            edge_index (LongTensor): Graph connectivity passed to
                :class:`GCNConv` and :meth:`augment_adj`.
            batch (LongTensor, optional): Vector passed to (and updated by)
                the pooling layers. Defaults to all zeros, one entry per node.

        Returns:
            Tensor: Node features produced by the last up-convolution. No
            activation is applied after the final layer.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        # `new_ones` already allocates on x's device with x's dtype.
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.act(self.down_convs[0](x, edge_index, edge_weight))

        # Per-level state needed by the decoder. The deepest level's features
        # are not stored because they are carried forward in `x`.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(
                edge_index, edge_weight, x.size(0))
            x, edge_index, edge_weight, batch, perm = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            x = self.act(self.down_convs[i](x, edge_index, edge_weight))

            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
            perms.append(perm)

        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Unpool: write the coarse features into the rows indexed by
            # `perm` (presumably the nodes StarPooling kept); all other rows
            # stay zero.
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            if i < self.depth - 1:
                x = self.act(x)
        return x

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        """Square the self-loop-augmented adjacency before pooling.

        Self-loops are added, the edge list is sorted, the sparse weighted
        adjacency is multiplied with itself via ``spspmm`` (connecting nodes
        at most two hops apart), and self-loops are removed from the product.

        Returns:
            tuple: ``(edge_index, edge_weight)`` of the augmented graph.
        """
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        # NOTE(review): sorting presumably satisfies spspmm's expectation of
        # row-sorted indices; confirm against the torch_sparse version in use.
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def __repr__(self):
        return '{}({}, {}, {}, depth={})'.format(
            self.__class__.__name__, self.in_channels, self.hidden_channels,
            self.out_channels, self.depth)