-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
117 lines (98 loc) · 3.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
setup model and datasets
"""
import torch
import torch.nn as nn
# from advertorch.utils import NormalizeByChannelMeanStd
# from advertorch.utils import NormalizeByChannelMeanStd
from pruner.pruner import NormalizeByChannelMeanStd
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
from dataset import *
from models import *
# Public names exported by `from utils import *`.
# NOTE(review): `setup_model_dataset` is not defined in this file. Presumably it is
# re-exported via `from dataset import *` or `from models import *`; verify, because a
# star import of a missing `__all__` name raises AttributeError. `evaluate_cer` is
# defined here but not listed, so star-importers do not receive it.
__all__ = ["setup_model_dataset", "setup_model"]
def _cifar_test_loader(dataset_cls, root):
    """Return an unshuffled DataLoader over the CIFAR test split under ``root``.

    Images are only converted with ``ToTensor``; no normalization is applied here.
    ``download=True`` fetches the dataset if it is missing.
    """
    test_set = dataset_cls(
        root,
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True,
    )
    return DataLoader(
        test_set,
        batch_size=128,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )


def evaluate_cer(net, args, loader_=None):
    """Evaluate ``net`` on a test set and return the number of misclassified samples.

    The evaluation data is selected from ``args.dataset``:

    * ``"cifar10"`` / ``"cifar100"``: the torchvision test split stored under
      ``args.data``. ``loader_`` is ignored, as before.
    * any other value (e.g. ``"restricted_imagenet"``): the caller-supplied
      ``loader_``.

    The function prints the accuracy in percent, the mean per-sample
    cross-entropy loss multiplied by 100, and the misclassified count.

    Args:
        net: classifier module. It is moved to the evaluation device (CUDA when
            available, otherwise CPU) and switched to ``eval()`` mode, both in place.
        args: namespace providing ``dataset`` and, for the CIFAR datasets,
            ``data`` (the dataset root directory).
        loader_: iterable of ``(inputs, targets)`` batches. It is required when
            ``args.dataset`` is not a CIFAR variant.

    Returns:
        int: ``total - correct`` over all evaluated samples (0 for an empty loader).

    Raises:
        ValueError: if ``args.dataset`` is not CIFAR and no ``loader_`` was given.
    """
    if args.dataset == "cifar10":
        test_loader = _cifar_test_loader(CIFAR10, args.data)
    elif args.dataset == "cifar100":
        test_loader = _cifar_test_loader(CIFAR100, args.data)
    elif loader_ is not None:
        test_loader = loader_
    else:
        # Previously this surfaced as an UnboundLocalError / TypeError on len(None).
        raise ValueError(
            f"evaluate_cer: no test loader for dataset {args.dataset!r}; "
            "pass loader_ explicitly"
        )

    # Use the GPU when present instead of failing on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    net.to(device)
    net.eval()

    correct = 0
    total = 0  # number of samples seen
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)
            outputs = net(inputs)
            # CrossEntropyLoss averages over the batch by default; re-weight by
            # batch size so total_loss is a per-sample sum across batches.
            total_loss += criterion(outputs, targets).item() * batch_size
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += batch_size

    misclassified = total - correct
    if total:  # avoid ZeroDivisionError on an empty loader
        print("Correct %")
        print(100 * correct / total)
        print("Total Loss")
        # Mean per-sample loss scaled by 100, kept as-is for log compatibility.
        print(total_loss * 100 / total)
    print(f"misclassified samples from {total}")
    print(misclassified)
    return misclassified
def setup_model(args):
    """Build the classifier named by ``args.arch`` for ``args.dataset``.

    * ``"cifar10"`` / ``"cifar100"``: 10 / 100 output classes. A
      ``NormalizeByChannelMeanStd`` layer with the dataset's per-channel
      mean/std is attached as ``model.normalize``.
    * ``"restricted_imagenet"``: 14 output classes and no normalization layer.

    When ``args.imagenet_arch`` is truthy, the architecture constructor also
    receives ``imagenet=True``.

    Args:
        args: namespace providing ``dataset``, ``arch`` and ``imagenet_arch``.
            ``args.arch`` must be a key of ``model_dict`` (presumably supplied
            by ``from models import *`` — verify).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    # dataset -> (num_classes, per-channel mean, per-channel std).
    # A mean of None means no normalization layer is attached.
    specs = {
        "cifar10": (10, [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        "cifar100": (100, [0.5071, 0.4866, 0.4409], [0.2673, 0.2564, 0.2762]),
        "restricted_imagenet": (14, None, None),
    }
    if args.dataset not in specs:
        # Previously this surfaced as an UnboundLocalError on `classes`.
        raise ValueError(
            f"setup_model: unsupported dataset {args.dataset!r}; "
            f"expected one of {sorted(specs)}"
        )
    classes, mean, std = specs[args.dataset]

    # Build the normalization layer before the model, as the original code did.
    normalization = (
        NormalizeByChannelMeanStd(mean=mean, std=std) if mean is not None else None
    )

    if args.imagenet_arch:
        model = model_dict[args.arch](num_classes=classes, imagenet=True)
    else:
        model = model_dict[args.arch](num_classes=classes)

    if normalization is not None:
        model.normalize = normalization
    return model