-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
152 lines (129 loc) · 4.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import torch.nn as nn
import torch.optim as optim
import tensorboardX
import os
import random
import numpy as np
from train import train_epoch
from torch.utils.data import DataLoader
from validation import val_epoch
from opts import parse_opts
from model import generate_model
from torch.optim import lr_scheduler
from dataset import get_training_set, get_validation_set
from mean import get_mean, get_std
from torchvision import transforms
from target_transforms import ClassLabel, VideoID
import time
def resume_model(opt, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        opt: options namespace; ``opt.resume_path`` is the checkpoint path.
        model: model whose ``state_dict`` is loaded in place.
        optimizer: optimizer whose state is loaded in place.

    Returns:
        int: the epoch to resume from (checkpoint epoch + 1).
    """
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
    # machine; load_state_dict then moves tensors onto the model's device.
    checkpoint = torch.load(opt.resume_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("Model Restored from Epoch {}".format(checkpoint['epoch']))
    start_epoch = checkpoint['epoch'] + 1
    return start_epoch
def get_loaders(opt):
    """Build the training and validation DataLoaders.

    Args:
        opt: options namespace; uses ``norm_value``, ``mean_dataset``,
            ``no_mean_norm``, ``std_norm``, ``sample_size``, ``batch_size``
            and ``num_workers``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    opt.mean = get_mean(opt.norm_value, dataset=opt.mean_dataset)
    # opt.std = get_std()
    # Pick the normalization: identity, mean-only, or mean+std.
    if opt.no_mean_norm and not opt.std_norm:
        norm_method = transforms.Normalize([0, 0, 0], [1, 1, 1])
    elif not opt.std_norm:
        norm_method = transforms.Normalize(opt.mean, [1, 1, 1])
    else:
        # NOTE(review): opt.std must be provided by the caller — the
        # get_std() call above is commented out; confirm before using
        # std_norm.
        norm_method = transforms.Normalize(opt.mean, opt.std)
    spatial_transform = transforms.Compose([
        # crop_method,
        # transforms.Scale is deprecated/removed in modern torchvision;
        # Resize is the supported equivalent.
        transforms.Resize((opt.sample_size, opt.sample_size)),
        # grayscale
        # transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        norm_method
    ])
    temporal_transform = None  # TemporalRandomCrop(16)
    target_transform = ClassLabel()

    # train loader: shuffled for SGD
    training_data = get_training_set(opt, spatial_transform,
                                     temporal_transform, target_transform)
    train_loader = DataLoader(
        training_data,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True)

    # validation loader: fixed order for reproducible evaluation
    validation_data = get_validation_set(opt, spatial_transform,
                                         temporal_transform, target_transform)
    val_loader = DataLoader(
        validation_data,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
        pin_memory=True)
    return train_loader, val_loader
def main_worker():
    """Entry point: parse options, build model and loaders, train, and save
    the best-validation-accuracy checkpoint under ``snapshots/<model>/``.
    """
    opt = parse_opts()
    print(opt)

    # Fix seeds for reproducibility.
    seed = 1
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # CUDA for PyTorch
    device = torch.device("cuda" if opt.use_cuda else "cpu")

    # tensorboard
    summary_writer = tensorboardX.SummaryWriter(log_dir='tf_logs')

    # defining model
    model = generate_model(opt, device)
    # get data loaders
    train_loader, val_loader = get_loaders(opt)

    # optimizer over all model parameters
    crnn_params = list(model.parameters())
    optimizer = torch.optim.Adam(
        crnn_params, lr=opt.lr_rate, weight_decay=opt.weight_decay)
    # scheduler = lr_scheduler.ReduceLROnPlateau(
    #     optimizer, 'min', patience=opt.lr_patience)
    criterion = nn.CrossEntropyLoss()

    # resume model
    if opt.resume_path:
        start_epoch = resume_model(opt, model, optimizer)
    else:
        start_epoch = 1

    # start training
    best_val_accuracy = 0
    best_state = None
    best_epoch = None
    best_val_loss = None
    for epoch in range(start_epoch, opt.n_epochs + 1):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, epoch,
            opt.log_interval, device)
        val_loss, val_acc = val_epoch(model, val_loader, criterion, device)

        # Metrics are logged and the best checkpoint tracked only on
        # save_interval boundaries (matches the original behavior).
        if epoch % opt.save_interval == 0:
            # scheduler.step(val_loss)
            # write summary
            summary_writer.add_scalar(
                'losses/train_loss', train_loss, global_step=epoch)
            summary_writer.add_scalar(
                'losses/val_loss', val_loss, global_step=epoch)
            summary_writer.add_scalar(
                'acc/train_acc', train_acc * 100, global_step=epoch)
            summary_writer.add_scalar(
                'acc/val_acc', val_acc * 100, global_step=epoch)

            state = {'epoch': epoch,
                     'state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict()}
            if val_acc >= best_val_accuracy:
                best_state = state
                best_epoch = epoch
                best_val_loss = val_loss
                best_val_accuracy = val_acc

    # Guard: best_state stays None when no epoch hit a save_interval
    # boundary — the original would save a useless None checkpoint.
    if best_state is not None:
        # makedirs creates the "snapshots/" parent too if missing;
        # the original os.mkdir would raise FileNotFoundError there.
        os.makedirs(os.path.join("snapshots", opt.model), exist_ok=True)
        timestamp = time.strftime('%b-%d-%Y_%H%M', time.localtime())
        torch.save(best_state, os.path.join(
            "snapshots", opt.model,
            f'{opt.model}-Epoch-{best_epoch}-Loss-{best_val_loss}_{timestamp}.pth'))
        print("Best model saved with val accuracy ", best_val_accuracy)
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main_worker()