-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
35 lines (27 loc) · 924 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Training a SEAL-CI model."""
import torch
from utils import tab_printer
from subgraph import Subgraph_Learning
from param_parser import parameter_parser
import torch._utils
def _install_rebuild_tensor_v2_fallback():
    """Provide ``torch._utils._rebuild_tensor_v2`` if this PyTorch lacks it.

    If the attribute already exists, nothing is changed. Otherwise a fallback
    is registered on ``torch._utils`` that:

    - rebuilds the tensor with ``torch._utils._rebuild_tensor``;
    - copies ``requires_grad`` onto the rebuilt tensor;
    - attaches the backward hooks to it.

    NOTE(review): presumably this lets an older PyTorch unpickle tensors that
    were saved by a newer release — confirm against the checkpoints in use.
    """
    if hasattr(torch._utils, "_rebuild_tensor_v2"):
        return

    def _rebuild_tensor_v2(storage, storage_offset, size, stride,
                           requires_grad, backward_hooks):
        rebuilt = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        rebuilt.requires_grad = requires_grad
        rebuilt._backward_hooks = backward_hooks
        return rebuilt

    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2


def main():
    """
    Run a full SEAL-CI training session from the command line.

    Steps:

    1. Ensure the tensor-rebuild compatibility hook is present.
    2. Parse the command-line parameters and print them via ``tab_printer``.
    3. Build a ``Subgraph_Learning`` trainer from those parameters.
    4. Call its ``fit`` and then its ``test`` method.
    """
    _install_rebuild_tensor_v2_fallback()

    parsed_args = parameter_parser()
    tab_printer(parsed_args)

    model_trainer = Subgraph_Learning(parsed_args)
    model_trainer.fit()
    model_trainer.test()
if __name__ == "__main__":
    # Start training only when this file is executed directly, not on import.
    main()