-
Notifications
You must be signed in to change notification settings - Fork 0
/
functs.py
92 lines (68 loc) · 2.72 KB
/
functs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Track the most recent value, a running average, and the average's history.

    Two accumulation modes are provided:

    * ``update(val, n)`` weights ``val`` by ``n`` (adds ``val * n`` to ``sum``),
      i.e. ``val`` is treated as a mean over ``n`` samples.
    * ``update_simplesum(val, n)`` adds ``val`` to ``sum`` unweighted, i.e.
      ``val`` is already a total over ``n`` samples (e.g. a count of correct
      predictions).

    After either call ``avg == sum / count`` and that average is appended to
    ``history``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic and start a brand-new (not cleared) history list."""
        self.val = self.avg = self.sum = self.count = 0
        self.history = []

    def _accumulate(self, val, contribution, n):
        """Store ``val``, add ``contribution`` / ``n`` to the totals, refresh ``avg``."""
        self.val = val
        self.sum += contribution
        self.count += n
        # Raises ZeroDivisionError while ``count`` is still 0 (e.g. a first call with n=0).
        self.avg = self.sum / self.count
        self.history.append(self.avg)

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples."""
        self._accumulate(val, val * n, n)

    def update_simplesum(self, val, n=1):
        """Record ``val`` as an already-summed total over ``n`` samples."""
        self._accumulate(val, val, n)
def train_fn(train_loader, model, criterion, optimizer, scheduler, device):
    """Run one training pass over ``train_loader``.

    Each batch is unpacked as ``(images, labels, paths, xfeatures)``; ``paths``
    is ignored. The model is called as ``model(images, xfeatures)``, the loss
    is back-propagated, and then the optimizer and scheduler are stepped.

    A sample counts as correct when the int32 mask marking its row-wise
    maximum output equals its ``labels`` row at every position (ties at the
    maximum set several positions to 1). This assumes the model output is
    (batch, classes) and that ``labels`` matches it — presumably one-hot;
    confirm against the dataset.

    Returns:
        ``(loss_history, accuracy_history)``: the running batch-size-weighted
        mean loss and the running accuracy, recorded after every batch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.train()

    for images, labels, _paths, xfeatures in train_loader:
        images, labels, xfeatures = (
            images.to(device),
            labels.to(device),
            xfeatures.to(device),
        )
        batch_size = images.size(0)

        outputs = model(images, xfeatures)
        loss = criterion(outputs, labels)

        # 1 wherever an output equals its row maximum, else 0.
        row_max = outputs.max(dim=1, keepdim=True)[0]
        predicted = (outputs == row_max).to(dtype=torch.int32)

        loss_meter.update(loss.item(), batch_size)
        exact_rows = torch.all(torch.eq(predicted, labels), dim=1)
        acc_meter.update_simplesum(torch.sum(exact_rows).item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): the source's indentation was lost; this reconstructs
        # scheduler.step() as per-batch. Confirm it was not meant per-epoch.
        scheduler.step()

    print(f'Train Loss: {loss_meter.avg:.4f} Acc: {acc_meter.avg:.4f}')
    return loss_meter.history, acc_meter.history
def valid_fn(valid_loader, model, criterion, device):
    """Evaluate ``model`` over ``valid_loader`` without tracking gradients.

    Each batch is unpacked as ``(images, labels, paths, xfeatures)``; ``paths``
    is ignored. The forward pass, loss, and prediction mask run under
    ``torch.no_grad()`` with the model in eval mode.

    A sample counts as correct when the int32 mask marking its row-wise
    maximum output equals its ``labels`` row at every position. This assumes
    the model output is (batch, classes) and that ``labels`` matches it —
    presumably one-hot; confirm against the dataset.

    Returns:
        ``(loss_history, accuracy_history)``: the running batch-size-weighted
        mean loss and the running accuracy, recorded after every batch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.eval()

    for images, labels, _paths, xfeatures in valid_loader:
        images, labels, xfeatures = (
            images.to(device),
            labels.to(device),
            xfeatures.to(device),
        )
        batch_size = images.size(0)

        with torch.no_grad():
            outputs = model(images, xfeatures)
            loss = criterion(outputs, labels)
            # 1 wherever an output equals its row maximum, else 0.
            row_max = outputs.max(dim=1, keepdim=True)[0]
            predicted = (outputs == row_max).to(dtype=torch.int32)

        loss_meter.update(loss.item(), batch_size)
        exact_rows = torch.all(torch.eq(predicted, labels), dim=1)
        acc_meter.update_simplesum(torch.sum(exact_rows).item(), batch_size)

    print(f'Val Loss: {loss_meter.avg:.4f} Acc: {acc_meter.avg:.4f}')
    return loss_meter.history, acc_meter.history