-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_train.py
93 lines (80 loc) · 2.32 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import os
os.getcwd()
from train_rbd import train
import networks
# Default dataset name; used as the default value of --data_name below.
name = "binding"
parser = argparse.ArgumentParser(
    description="Train a network on the positive ('exper') and negative "
    "('nonexp') samples found under ./data/<data_name>."
)
#=======================================================================================================================
parser.add_argument(
    "--model_name",
    "-mn",
    # Previously two adjacent string literals were implicitly concatenated,
    # producing garbled help text ("network for training.-mn for ").
    help="Network to train; the name is passed to networks.get_net "
    "(default: %(default)s).",
    type=str,
    default="XBCR_net",
    required=False,
)
parser.add_argument(
    "--data_name",
    "-dn",
    help="Dataset name; data is read from ./data/<data_name> "
    "(default: %(default)s).",
    type=str,
    default=name,
    required=False,
)
parser.add_argument(
    "--type",
    # NOTE(review): this option is parsed but never read in this script and is
    # not forwarded to train(); confirm whether it should be wired through.
    help="Training type, full or rbd or multi",
    default="rbd",
    type=str,
    required=False,
)
parser.add_argument(
    "--model_num",
    help="The model number.",
    type=int,
    default=0,
)
parser.add_argument(
    "--max_epochs",
    # Forwarded unchanged to train() as nb_epochs1; the "-1" convention is
    # presumably interpreted inside train() — not verifiable from this file.
    help="The maximum number of epochs, -1 means following configuration.",
    type=int,
    default=1000,
)
parser.add_argument(
    "--include_light",
    # Integer used as an on/off flag (default 1); forwarded to train().
    help="include light or not.",
    type=int,
    default=1,
)
parser.add_argument(
    "--restore_pretrain",
    # Integer used as an on/off flag (default 1); forwarded to train().
    help="restore pre-trained model or not.",
    type=int,
    default=1,
)
#=======================================================================================================================
args = parser.parse_args()
#=======================================================================================================================
# Unpack the parsed CLI options into the names passed to train() below.
# NOTE(review): args.type (--type) is never read here; confirm whether train()
# is meant to receive it.
model_num = args.model_num
batch_size = 12  # fixed in this script; not exposed as a CLI option
nb_epochs1 = args.max_epochs
model_name = args.model_name
data_name = args.data_name
include_light = args.include_light
restore_pre_train = args.restore_pretrain
# network setting
net_core = networks.get_net(model_name)
# All paths below are relative to the working directory, so report it first.
# (A bare `os.getcwd()` statement whose result was discarded was removed here.)
print(os.getcwd())
model_path = os.path.join('.', 'models', data_name, data_name + '-' + model_name, 'model')
data_path = os.path.join('.', 'data', data_name)
print('model:', model_path, ' data:', data_path)
print(os.path.abspath(data_path))
# training data: pos_path <- ./data/<data_name>/exper, neg_path <- ./data/<data_name>/nonexp
pos_path = os.path.join(data_path, 'exper')
neg_path = os.path.join(data_path, 'nonexp')
train(
    net_core=net_core,
    model_path=model_path,
    model_num=model_num,
    include_light=include_light,
    pos_path=pos_path,
    neg_path=neg_path,
    batch_size=batch_size,
    nb_epochs1=nb_epochs1,
    restore_pre_train=restore_pre_train,
)