-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
115 lines (106 loc) · 3.2 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import lightning.pytorch as lp
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from crystalproject.module.predictor_module import PreModule
from crystalproject.data.map_data_module import MapDataModule
# Configuration.
#
# ``module_config`` is unpacked into ``PreModule(**module_config)`` in the main
# flow, so each top-level key below becomes a PreModule keyword argument.
module_config = dict(
    # Backbone network: selected by ``name`` and built from ``kwargs``.
    backbone=dict(
        name="toponet",
        kwargs=dict(
            atom_embedding=dict(
                config_path="/home/gwh/project/crystalProject/models/crystalProject/crystalproject/assets/atom_init.json",
            ),
            atom_hidden_channels=128,
            atom_graph=dict(
                name="cgcnn",
                kwargs=dict(
                    num_layers=3,
                    edge_embedding=dict(dmin=0.0, dmax=5.0, step=0.2),
                ),
            ),
            cluster_hidden_channels=128,
            cluster_graph=dict(
                name="gcn",
                kwargs=dict(num_layers=2),
            ),
            # NOTE(review): "underling_network" is kept verbatim (presumably meant
            # "underlying"); the backbone presumably looks it up under this exact
            # key, so do not rename it here alone.
            underling_network=dict(
                name="cgcnn",
                kwargs=dict(
                    num_layers=3,
                    edge_embedding=dict(dmin=0.0, dmax=15.0, step=0.2),
                ),
            ),
        ),
    ),
    predictor=dict(
        # Target name -> weight (presumably per-target loss weights; both equal).
        targets={
            "absolute methane uptake high P [v STP/v]": 0.5,
            "absolute methane uptake low P [v STP/v]": 0.5,
        },
        heads=[
            dict(
                name="regression",
                kwargs=dict(
                    # NOTE(review): 137 looks like 128 (atom_hidden_channels, for
                    # "atom_graph_embedding") + 9 scalar descriptors below -- keep
                    # in sync if either changes; confirm against the head.
                    in_channels=137,
                    out_channels=2,  # matches the two entries in ``targets``
                    targets=[
                        "absolute methane uptake high P [v STP/v]",
                        "absolute methane uptake low P [v STP/v]",
                    ],
                    descriptors=[
                        "atom_graph_embedding",
                        "vol", "rho", "di", "df", "dif",
                        "asa", "av", "nasa", "nav",
                    ],
                ),
            ),
        ],
    ),
    optimizers=dict(
        name="Adam",
        kwargs=dict(lr=5e-4, weight_decay=0.1),
    ),
    # NOTE(review): step_size=500 equals trainer max_epochs=500; if the scheduler
    # steps once per epoch, the LR decay would never take effect during training.
    # Confirm the intended stepping interval / step_size.
    scheduler=dict(
        name="StepLR",
        kwargs=dict(step_size=500),
    ),
    loss=dict(name="mse"),
    criterion=dict(name="mae"),
)
# ``data_config`` is unpacked into ``MapDataModule(**data_config)`` in the main flow.
data_config = dict(
    dataset=dict(
        name="CrystalTopoDataset",
        kwargs=dict(
            root_dir="/home/gwh/project/crystalProject/DATA/cofs_Methane/process",
            # Presumably the columns to load: the first two names equal the
            # predictor targets and the remaining nine equal the scalar
            # descriptors listed for the regression head.
            descriptor_index=[
                "absolute methane uptake high P [v STP/v]",
                "absolute methane uptake low P [v STP/v]",
                "vol", "rho", "di", "df", "dif",
                "asa", "av", "nasa", "nav",
            ],
        ),
    ),
    dataloader=dict(
        batch_size=16,
        num_workers=1,
        pin_memory=True,
    ),
)
# Keyword arguments for ``lp.Trainer``: train for at least 100 and at most 500 epochs.
trainer_config = dict(max_epochs=500, min_epochs=100)
# Callbacks.
#
# NOTE(review): the two callbacks watch different metrics -- EarlyStopping uses
# "val_mae" (presumably the mean over targets) while ModelCheckpoint ranks
# checkpoints by the high-pressure target alone. Confirm the split is intended.

# Stop once "val_mae" stops decreasing (Lightning's default patience applies;
# trainer min_epochs still prevents stopping early on).
early_stop = EarlyStopping(
    monitor="val_mae",
    mode="min",
)
model_checkpoint = ModelCheckpoint(
    # Every key in the filename template must be a logged metric: Lightning
    # substitutes 0 for unknown keys, so an unlogged name yields "...=0.00" in
    # every checkpoint name. Use the same metric EarlyStopping relies on.
    filename="model-{epoch:02d}-{val_mae:.2f}",
    save_top_k=3,
    monitor="val_mae_absolute methane uptake high P [v STP/v]",
    mode="min",
)
# Main flow: build the model and data module, train, then evaluate.
#
# The ``__main__`` guard keeps training from re-running whenever this file is
# imported -- e.g. by DataLoader worker processes under the "spawn" start
# method, or by another script that only wants the configuration dicts.
if __name__ == "__main__":
    module = PreModule(**module_config)
    data_module = MapDataModule(**data_config)
    # Callbacks only take effect when handed to the Trainer; constructing them
    # at module level is not enough.
    trainer = lp.Trainer(
        **trainer_config,
        devices=[2],  # hard-coded device index; adjust for the target machine
        callbacks=[early_stop, model_checkpoint],
    )
    trainer.fit(module, datamodule=data_module)
    # Evaluate the best checkpoint tracked by ``model_checkpoint`` instead of
    # whatever weights the last (possibly post-plateau) epoch left behind.
    trainer.test(module, datamodule=data_module, ckpt_path="best")