# main.py
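"""Training entry point for the sleep-staging models in this repo.

Supports plain supervised training and knowledge distillation (offline or
online) for the CMT and USleep architectures on ear-EEG or scalp-EEG signals.
(Summary inferred from the CLI options and training branches below.)
"""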
import argparse
import os

import torch
import torch.nn as nn

from Datasets.Dataset_loader import get_dataset
from Model.CrossModalTransformer import Cross_Transformer_Network
from Model.USleep import USleep
from Model.Training import supervised_training, KD_online_training, KD_offline_training


def parse_option():
    parser = argparse.ArgumentParser(description='Arguments for training')

    parser.add_argument('--project_path', type=str, default='./results', help='Path to store project results')
    parser.add_argument('--data_path', type=str, default='', help='Path to the dataset file')
    parser.add_argument('--train_data_list', nargs="+", default=['[0,1,2,3,5,6,7]'], help='Folds in the dataset for training')
    parser.add_argument('--val_data_list', nargs="+", default=['[8]'], help='Folds in the dataset for validation')
    # argparse's type=bool parses any non-empty string as True, so the boolean
    # switches below use action='store_true' instead.
    parser.add_argument('--is_retrain', action='store_true', help='Retrain from a saved checkpoint')
    parser.add_argument('--is_student_pretrain', action='store_true', help='Use a pretrained model for the student')
    parser.add_argument('--is_teacher_pretrain', action='store_true', help='Use a pretrained model for the teacher (online KD only)')
    parser.add_argument('--model_path', type=str, default="", help='Path to a saved checkpoint for retraining')
    parser.add_argument('--student_model_path', type=str, default="", help='Path to a saved checkpoint for student initialization')
    parser.add_argument('--signals', type=str, default='ear-eeg', choices=['ear-eeg', 'scalp-eeg'], help='Signal type')
    parser.add_argument('--model', type=str, default='USleep', choices=['USleep', 'CMT'], help='Model architecture')

    # Model parameters
    parser.add_argument('--training_type', type=str, default='Knowledge_distillation', choices=['supervised', 'Knowledge_distillation'], help='Training type')
    parser.add_argument('--KD_type', type=str, default='online', choices=['offline', 'online'], help='Knowledge distillation type')
    parser.add_argument('--d_model', type=int, default=256, help='Embedding size of the CMT')
    parser.add_argument('--dim_feedforward', type=int, default=1024, help='Number of neurons in the feed-forward block')
    parser.add_argument('--window_size', type=int, default=50, help='Size of the non-overlapping window')
    parser.add_argument('--depth', type=int, default=12, help='Depth of the USleep model')

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')

    # Optimizer
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--beta_1', type=float, default=0.9, help='Beta 1 for the Adam optimizer')
    parser.add_argument('--beta_2', type=float, default=0.999, help='Beta 2 for the Adam optimizer')
    parser.add_argument('--eps', type=float, default=1e-9, help='Epsilon for the Adam optimizer')
    parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay for the Adam optimizer')
    parser.add_argument('--n_epochs', type=int, default=100, help='Number of training epochs')

    # Neptune
    parser.add_argument('--is_neptune', action='store_true', help='Track experiments with Neptune')
    parser.add_argument('--nep_project', type=str, default='', help='Neptune project name')
    parser.add_argument('--nep_api', type=str, default='', help='Neptune API token')

    opt = parser.parse_args()
    return opt
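
# Hypothetical invocation, shown for illustration only (paths and fold lists
# are placeholders, not values shipped with the repo):
#   python main.py --data_path /path/to/dataset --project_path ./results \
#       --model USleep --training_type supervised --train_data_list [0,1,2] \
#       --val_data_list [3]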


def main():
    args = parse_option()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cuda = torch.cuda.is_available()
    if cuda:
        torch.backends.cudnn.benchmark = True
    print(device)

    print("<=============== Training arguments ===============>")
    for arg in vars(args):
        print(f"{arg} : {getattr(args, arg)}")

    if not os.path.isdir(args.project_path):
        os.makedirs(args.project_path)
        print(f"Project directory created at {args.project_path}")
    else:
        print(f"Project directory available at {args.project_path}")

    print("<=============== Getting Dataset ===============>")
    train_data_loader, val_data_loader = get_dataset(args, device)

    if not os.path.isdir(os.path.join(args.project_path, "model_check_points")):
        os.makedirs(os.path.join(args.project_path, "model_check_points"))
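
    # Supervised baseline: a single network trained with cross-entropy alone.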
    if args.training_type == 'supervised':
        if args.is_retrain:
            print(f"Loading previous model from {args.model_path}")
            Net = torch.load(args.model_path).to(device)
        else:
            print(f"Initializing {args.model} model")
            if args.model == 'CMT':
                Net = Cross_Transformer_Network(d_model=args.d_model, dim_feedforward=args.dim_feedforward, window_size=args.window_size).to(device)
            elif args.model == 'USleep':
                Net = USleep(in_chans=3, sfreq=200, depth=args.depth, with_skip_connection=True, n_classes=5, input_size_s=30, time_conv_size_s=9 / 200, apply_softmax=False).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(Net.parameters(), lr=args.lr, betas=(args.beta_1, args.beta_2), eps=args.eps, weight_decay=args.weight_decay)

        supervised_training(Net, train_data_loader, val_data_loader, criterion, optimizer, args, device)
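
    # Knowledge distillation: the routines below receive both a cross-entropy
    # criterion (for the labels) and an MSE criterion (for the distillation
    # term); how the two are combined is defined in Model.Training, not here.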
    elif args.training_type == 'Knowledge_distillation':
        if args.KD_type == 'offline':
            # Offline KD: the teacher is always loaded from a pretrained
            # checkpoint and only the student receives an optimizer.
            print(f"Loading previous model from {args.model_path} for teacher model")
            Net_t = torch.load(args.model_path).to(device)

            if args.is_student_pretrain:
                print(f"Loading previous model from {args.student_model_path} for student model")
                Net_s = torch.load(args.student_model_path).to(device)
            else:
                print("Initializing student model")
                if args.model == 'CMT':
                    Net_s = Cross_Transformer_Network(d_model=args.d_model, dim_feedforward=args.dim_feedforward, window_size=args.window_size).to(device)
                elif args.model == 'USleep':
                    Net_s = USleep(in_chans=3, sfreq=200, depth=args.depth, with_skip_connection=True, n_classes=5, input_size_s=30, time_conv_size_s=9 / 200, apply_softmax=False).to(device)

            criterion_mse = nn.MSELoss()
            criterion_ce = nn.CrossEntropyLoss()
            optimizer_s = torch.optim.Adam(Net_s.parameters(), lr=args.lr, betas=(args.beta_1, args.beta_2), eps=args.eps)

            KD_offline_training(Net_s, Net_t, train_data_loader, val_data_loader, criterion_ce, criterion_mse, optimizer_s, args, device)
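
        # Online KD: unlike the offline path, the teacher is trained jointly
        # with the student and therefore gets its own Adam optimizer; a
        # pretrained teacher checkpoint is optional.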
        elif args.KD_type == 'online':
            if args.is_student_pretrain:
                print(f"Loading previous model from {args.student_model_path} for student model")
                Net_s = torch.load(args.student_model_path).to(device)
            else:
                print("Initializing student model")
                if args.model == 'CMT':
                    Net_s = Cross_Transformer_Network(d_model=args.d_model, dim_feedforward=args.dim_feedforward, window_size=args.window_size).to(device)
                elif args.model == 'USleep':
                    Net_s = USleep(in_chans=3, sfreq=200, depth=args.depth, with_skip_connection=True, n_classes=5, input_size_s=30, time_conv_size_s=9 / 200, apply_softmax=False).to(device)
            print(Net_s)  # log the student architecture

            if args.is_teacher_pretrain:
                print(f"Loading previous model from {args.model_path} for teacher model")
                Net_t = torch.load(args.model_path).to(device)
            else:
                print("Initializing teacher model")
                if args.model == 'CMT':
                    Net_t = Cross_Transformer_Network(d_model=args.d_model, dim_feedforward=args.dim_feedforward, window_size=args.window_size).to(device)
                elif args.model == 'USleep':
                    Net_t = USleep(in_chans=3, sfreq=200, depth=args.depth, with_skip_connection=True, n_classes=5, input_size_s=30, time_conv_size_s=9 / 200, apply_softmax=False).to(device)

            criterion_mse = nn.MSELoss()
            criterion_ce = nn.CrossEntropyLoss()
            optimizer_s = torch.optim.Adam(Net_s.parameters(), lr=args.lr, betas=(args.beta_1, args.beta_2), eps=args.eps)
            optimizer_t = torch.optim.Adam(Net_t.parameters(), lr=args.lr, betas=(args.beta_1, args.beta_2), eps=args.eps)

            KD_online_training(Net_s, Net_t, train_data_loader, val_data_loader, criterion_ce, criterion_mse, optimizer_t, optimizer_s, args, device)


if __name__ == '__main__':
    main()