train.py
""" 학습 코드
TODO:
NOTES:
REFERENCE:
* MNC 코드 템플릿 train.py
UPDATED:
"""
# WandB: load the library
import wandb

import os
import random
from datetime import datetime, timezone, timedelta

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.model import PestClassifier
from modules.dataset import CustomDataset, TestDataset
from modules.earlystoppers import LossEarlyStopper
from modules.metrics import get_metric_fn
from modules.recorders import PerformanceRecorder
from modules.trainer import Trainer
from modules.utils import load_yaml, save_yaml, get_logger, make_directory

# Restrict training to GPUs 0 and 1 (must be set before CUDA is first initialized)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# DEBUG
DEBUG = False
# CONFIG
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_PROJECT_DIR = os.path.dirname(PROJECT_DIR)
DATA_DIR = '../shared/Split/'
TRAIN_CONFIG_PATH = os.path.join(PROJECT_DIR, 'config/train_config.yml')
config = load_yaml(TRAIN_CONFIG_PATH)
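# The constants below read the following keys. A minimal sketch of
# config/train_config.yml (values are illustrative placeholders, not the
# repo's actual settings):
#
#   SEED:
#     random_seed: 42
#   TRAIN:
#     num_epochs: 100
#     batch_size: 32
#     learning_rate: 0.0001
#     early_stopping_patience: 10
#     model: efficientnet
#     optimizer: adam
#     scheduler: onecycle
#     momentum: 0.9
#     weight_decay: 0.0001
#     loss_function: crossentropy
#     metric_function: f1
#     input_shape: 128
#     layer: fc                 # read by the wandb section below
#     img_aug: base             # read by the wandb section below
#     softmax: none             # read by the wandb section below
#     initialization: default   # read by the wandb section below
#   PERFORMANCE_RECORD:
#     column_list: [train_serial, train_timestamp, model, ...]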
# SEED
RANDOM_SEED = config['SEED']['random_seed']
# TRAIN
EPOCHS = config['TRAIN']['num_epochs']
BATCH_SIZE = config['TRAIN']['batch_size']
LEARNING_RATE = config['TRAIN']['learning_rate']
EARLY_STOPPING_PATIENCE = config['TRAIN']['early_stopping_patience']
MODEL = config['TRAIN']['model']
OPTIMIZER = config['TRAIN']['optimizer']
SCHEDULER = config['TRAIN']['scheduler']
MOMENTUM = config['TRAIN']['momentum']
WEIGHT_DECAY = config['TRAIN']['weight_decay']
LOSS_FN = config['TRAIN']['loss_function']
METRIC_FN = config['TRAIN']['metric_function']
INPUT_SHAPE = config['TRAIN']['input_shape']
INPUT_SHAPE = (INPUT_SHAPE, INPUT_SHAPE)  # square input resolution, (height, width)
# TRAIN SERIAL
KST = timezone(timedelta(hours=9))
TRAIN_TIMESTAMP = datetime.now(tz=KST).strftime("%Y%m%d%H%M%S")
TRAIN_SERIAL = f'{MODEL}_{TRAIN_TIMESTAMP}' if not DEBUG else 'DEBUG'
# PERFORMANCE RECORD
PERFORMANCE_RECORD_DIR = os.path.join(PROJECT_DIR, 'results', 'train', TRAIN_SERIAL)
PERFORMANCE_RECORD_COLUMN_NAME_LIST = config['PERFORMANCE_RECORD']['column_list']
# TRAIN CONFIG LIST
def train_config(config_path):
    """Load a train config relative to PROJECT_DIR and return its input shape as an (H, W) tuple."""
    config = load_yaml(os.path.join(PROJECT_DIR, config_path))
    input_shape = config['TRAIN']['input_shape']
    return (input_shape, input_shape)
if __name__ == '__main__':
    # Put the configs to experiment with in train_config_list and run them in
    # turn. Note that only input_shape is re-read per config here; the other
    # hyperparameters come from the config loaded at module level.
    train_config_list = ['config/train_config.yml']
    for cfg_path in train_config_list:
        INPUT_SHAPE = train_config(cfg_path)

    # WandB: initialize the experiment run
    wandb.init(project='eff-tomato_disease_classification', entity='benseo',
               config={"num_epochs": config['TRAIN']['num_epochs'],
                       "batch_size": config['TRAIN']['batch_size'],
                       "learning_rate": config['TRAIN']['learning_rate'],
                       "early_stopping_patience": config['TRAIN']['early_stopping_patience'],
                       "model": config['TRAIN']['model'],
                       "layer": config['TRAIN']['layer'],
                       "img_aug": config['TRAIN']['img_aug'],
                       "softmax": config['TRAIN']['softmax'],
                       "initialization": config['TRAIN']['initialization']})
    # Set the run name from the key hyperparameters
    wandb.run.name = (f"{config['TRAIN']['model']}"
                      f"-layer({config['TRAIN']['layer']})"
                      f"-early_stopping({config['TRAIN']['early_stopping_patience']})"
                      f"-img_aug({config['TRAIN']['img_aug']})"
                      f"-softmax({config['TRAIN']['softmax']})"
                      f"-{config['TRAIN']['initialization']}")
    wandb.run.save()
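    # With placeholder values like those in the config sketch above, the run
    # name resolves to something like:
    #   efficientnet-layer(fc)-early_stopping(10)-img_aug(base)-softmax(none)-default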
    # Set random seed
    torch.manual_seed(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True  # force deterministic conv algorithms
    torch.backends.cudnn.benchmark = False     # (reproducibility at some cost in speed)
    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)
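    # For full reproducibility on multi-GPU runs you would typically also seed
    # every CUDA device (a sketch of a common addition, not required by this script):
    #   torch.cuda.manual_seed_all(RANDOM_SEED)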
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Set train result directory
    make_directory(PERFORMANCE_RECORD_DIR)

    # Set system logger
    system_logger = get_logger(name='train', file_path=os.path.join(PERFORMANCE_RECORD_DIR, 'train_log.log'))
    # Load dataset & dataloader
    train_dataset = CustomDataset(data_dir=DATA_DIR, mode='train', input_shape=INPUT_SHAPE)
    validation_dataset = CustomDataset(data_dir=DATA_DIR, mode='val', input_shape=INPUT_SHAPE)
    # test_dataset = CustomDataset(data_dir=DATA_DIR, mode='test', input_shape=INPUT_SHAPE)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)  # no need to shuffle validation data
    # test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    # print('Train set samples:', len(train_dataset), 'Val set samples:', len(validation_dataset), 'Test set samples:', len(test_dataset))
    print('Train set samples:', len(train_dataset))
    # input size: 128 x 128 x 3
    # Load model
    model = PestClassifier(num_class=train_dataset.class_num).to(device)

    # WandB: track model gradients
    wandb.watch(model)
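    # Only cuda:0 is used by the single .to(device) above, even though
    # CUDA_VISIBLE_DEVICES exposes two GPUs. A minimal sketch of splitting
    # batches across both (assuming both are actually available):
    #   if torch.cuda.device_count() > 1:
    #       model = nn.DataParallel(model)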
    # Set optimizer, scheduler, loss function, metric function
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # One-cycle policy: the lr ramps up from max_lr/div_factor over the first
    # 10% of steps, peaks at max_lr, then anneals back down over one cycle
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e5,
                                              max_lr=0.0001, epochs=EPOCHS,
                                              steps_per_epoch=len(train_dataloader))
    criterion = nn.CrossEntropyLoss()  # applies softmax internally, so the model needs no softmax layer of its own
    metric_fn = get_metric_fn
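    # For reference, CrossEntropyLoss on raw logits is equivalent to
    # log-softmax followed by negative log-likelihood (a sketch, not used below):
    #   log_probs = torch.log_softmax(logits, dim=1)
    #   loss = torch.nn.functional.nll_loss(log_probs, targets)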
    # Set trainer
    trainer = Trainer(criterion, model, device, metric_fn, optimizer, scheduler, logger=system_logger)

    # Set earlystopper
    early_stopper = LossEarlyStopper(patience=EARLY_STOPPING_PATIENCE, verbose=True, logger=system_logger)
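    # LossEarlyStopper is driven below via check_early_stopping(loss=...) and
    # its .stop flag; presumably it counts epochs without improvement and sets
    # .stop once EARLY_STOPPING_PATIENCE is exceeded (see modules/earlystoppers.py).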
    # Set performance recorder
    key_column_value_list = [
        TRAIN_SERIAL,
        TRAIN_TIMESTAMP,
        MODEL,
        OPTIMIZER,
        LOSS_FN,
        METRIC_FN,
        EARLY_STOPPING_PATIENCE,
        BATCH_SIZE,
        EPOCHS,
        LEARNING_RATE,
        WEIGHT_DECAY,
        RANDOM_SEED]
    performance_recorder = PerformanceRecorder(column_name_list=PERFORMANCE_RECORD_COLUMN_NAME_LIST,
                                               record_dir=PERFORMANCE_RECORD_DIR,
                                               key_column_value_list=key_column_value_list,
                                               logger=system_logger,
                                               model=model,
                                               optimizer=optimizer,
                                               scheduler=scheduler)
    # Train
    save_yaml(os.path.join(PERFORMANCE_RECORD_DIR, 'train_config.yaml'), config)
    best_loss = float('inf')  # best mean validation loss seen so far
    for epoch_index in tqdm(range(EPOCHS)):
        trainer.train_epoch(train_dataloader, epoch_index)
        trainer.validate_epoch(validation_dataloader, epoch_index, 'val')

        # Performance record - csv & save elapsed_time
        performance_recorder.add_row(epoch_index=epoch_index,
                                     train_loss=trainer.train_mean_loss,
                                     validation_loss=trainer.val_mean_loss,
                                     train_score=trainer.train_score,
                                     validation_score=trainer.validation_score)

        # WandB: track train/validation loss and score
        wandb.log({"train_loss": trainer.train_mean_loss,
                   "validation_loss": trainer.val_mean_loss,
                   "train_score": trainer.train_score,
                   "validation_score": trainer.validation_score})

        # Performance record - plot
        performance_recorder.save_performance_plot(final_epoch=epoch_index)

        # Early stopping check
        early_stopper.check_early_stopping(loss=trainer.val_mean_loss)
        if early_stopper.stop:
            print('Early stopped')
            break

        print("Val Mean Loss:", trainer.val_mean_loss)
        # Checkpoint whenever the mean validation loss improves
        if trainer.val_mean_loss < best_loss:
            best_loss = trainer.val_mean_loss
            performance_recorder.weight_path = os.path.join(PERFORMANCE_RECORD_DIR, 'best.pt')
            performance_recorder.save_weight()
            print(f'{epoch_index} model saved')
        print('----------------------------------')
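    # To reuse the best checkpoint later (a sketch; assumes save_weight()
    # stores the model's state_dict at best.pt — check modules/recorders.py
    # for the actual format):
    #   model = PestClassifier(num_class=train_dataset.class_num)
    #   model.load_state_dict(torch.load(os.path.join(PERFORMANCE_RECORD_DIR, 'best.pt')))
    #   model.to(device).eval()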