import os
# restrict the process to GPU 0; must be set before torch is imported
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import datetime
from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup, CosineAnnealingScheduler
from ignite.engine import Events, Engine
from ignite.handlers import Timer, ModelCheckpoint
from ignite.metrics import Loss, Accuracy
from ignite.utils import convert_tensor
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
# helper functions
def prepare_batch(batch, device=None):
    x, y = batch
    return (convert_tensor(x, device=device), convert_tensor(y, device=device))
def create_trainer(model, optimizer, loss_fn, device=None):
    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        inputs, labels = prepare_batch(batch, device=device)
        # forward pass takes inputs only, matching the evaluator below
        preds = model(inputs)
        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()
        return loss.item()
    return Engine(_update)
def create_evaluator(model, metrics, device=None):
    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs, labels = prepare_batch(batch, device=device)
            preds = model(inputs)
        return preds, labels
    evaluator = Engine(_inference)
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    return evaluator
# config
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# save model
t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
snapshots = './snapshots/{}'.format(t)
if not os.path.exists(snapshots):
    os.makedirs(snapshots)
init_lr = 1e-1
end_lr = 1e-5
batch = 512
epochs = 20
workernum = 4
num_classes = 1000
# dataloader
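# The original file defines no loaders here, but `trainloader` and `valoader`
# are used below. A minimal sketch, assuming a standard ImageNet-style
# ImageFolder layout; the './data/train' and './data/val' paths are hypothetical.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
trainset = torchvision.datasets.ImageFolder('./data/train', transform=train_tf)
valset = torchvision.datasets.ImageFolder('./data/val', transform=val_tf)
trainloader = DataLoader(trainset, batch_size=batch, shuffle=True,
                         num_workers=workernum, pin_memory=True)
valoader = DataLoader(valset, batch_size=batch, shuffle=False,
                      num_workers=workernum, pin_memory=True)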
# nets
# the checkpoint prefix 'r101' below suggests ResNet-101; the original Net
# class body was an empty stub, so torchvision's ResNet-101 is assumed here
model = torchvision.models.resnet101(num_classes=num_classes)
print(model)
model.to(device)
# multi-gpus (CUDA_VISIBLE_DEVICES above controls how many GPUs are visible)
if torch.cuda.device_count() > 1:
    print('==================== Use {} GPUs ===================='.format(torch.cuda.device_count()))
    model = nn.DataParallel(model)
# loss function
loss_fn = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=0.9, weight_decay=5e-4)
# scheduler: cosine annealing over restart cycles (first cycle = 4 epochs,
# each cycle 1.5x longer, restart value decayed 10x per cycle),
# preceded by one epoch of linear warmup from 0 to init_lr
scheduler = CosineAnnealingScheduler(optimizer, 'lr', init_lr, end_lr, 4 * len(trainloader), cycle_mult=1.5, start_value_mult=0.1)
scheduler = create_lr_scheduler_with_warmup(scheduler, warmup_start_value=0., warmup_end_value=init_lr, warmup_duration=len(trainloader))
# create trainer
trainer = create_trainer(model, optimizer, loss_fn, device=device)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
# wall-clock timer; reset every 100 iterations so it reports time per 100 iters
timer = Timer(average=False)
# logging training loss
def log_loss(engine):
    i = engine.state.iteration
    e = engine.state.epoch
    if i % 100 == 0:
        print('[Epoch {:0>2d}, Iter {:0>7d}, {:.2f}s/100 iters, lr={:.4E}] loss={:.4f}'.format(
            e, i, timer.value(), optimizer.param_groups[0]['lr'], engine.state.output))
        timer.reset()
trainer.add_event_handler(Events.ITERATION_COMPLETED, log_loss)
# Evaluation
metrics = {
    'loss': Loss(loss_fn),
    'acc': Accuracy()
}
def score_fn(engine):
    return engine.state.metrics['acc']
evaluator = create_evaluator(model, metrics, device=device)
# run validation and report metrics at the end of every epoch
def log_metrics(engine):
    print('[INFO] Compute metrics...')
    metrics = evaluator.run(valoader).metrics
    print('    Validation Results - Average Loss: {:.4f} | Accuracy: {:.4f}'.format(metrics['loss'], metrics['acc']))
    print('[INFO] Complete metrics...')
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_metrics)
# save the model checkpoints: keep the 10 best snapshots by validation accuracy
saver = ModelCheckpoint(snapshots, 'r101', n_saved=10, score_name='acc', score_function=score_fn)
# unwrap DataParallel (if used) so the saved state_dict has clean key names
to_save = model.module if isinstance(model, nn.DataParallel) else model
evaluator.add_event_handler(Events.COMPLETED, saver, {'model': to_save})
# start training
print('[INFO] Start training...')
trainer.run(trainloader, max_epochs=epochs)
print('[INFO] Complete training...')