-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathex05_test.py
92 lines (77 loc) · 3.3 KB
/
ex05_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import DataLoader
from ex03_customdataset import CustomDataset
import pandas as pd
from torchvision import models
from ex04_main import FIX
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import rexnet
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def test_main(data_dir="./dataset/test", checkpoint="./12nd.pt"):
    """Build the test pipeline, load a trained checkpoint and evaluate it.

    Args:
        data_dir: directory handed to ``CustomDataset`` for the test split.
        checkpoint: path to a state dict saved from the same architecture
            that is constructed below (currently MobileNetV2 with a 2-way head).
    """
    test_aug = A.Compose([
        A.SmallestMaxSize(max_size=224),
        A.CenterCrop(width=200, height=200),
        # Standard ImageNet channel mean/std values.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    test_dataset = CustomDataset(data_dir, transform=test_aug)
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        # Pinned host memory only speeds up host->GPU copies; on a CPU-only
        # machine it is useless and makes PyTorch emit a warning.
        pin_memory=device.type == "cuda",
    )

    # Model selection: edit this section when switching architectures.
    # The active model must match the architecture the checkpoint was saved from.
    model = models.mobilenet_v2(pretrained=False)
    # Swap the final Linear for a 2-output head, reusing its input width.
    model.classifier[1] = nn.Linear(
        in_features=model.classifier[1].in_features, out_features=2
    )
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.to(device)

    # Alternative architectures (pass the matching checkpoint path):
    # model = models.__dict__["resnet50"](pretrained= False)
    # model.fc = nn.Linear(in_features= 2048, out_features= 2)
    # model.load_state_dict(torch.load("./best0.pt", map_location=device))
    # model.to(device)
    # model = models.__dict__["vgg16"](pretrained=False)
    # model.classifier[6] = nn.Linear(in_features= 4096, out_features= 2)
    # model.load_state_dict(torch.load("./best0.pt", map_location=device))
    # model.to(device)
    # model = models.__dict__["resnet18"](pretrained=False)
    # model.fc = nn.Linear(in_features= 512, out_features=2)
    # model.load_state_dict(torch.load("./result/3rd.pt", map_location=device))
    # model.to(device)
    # NOTE(review): the file imports `rexnet`, not `rexnetv1`; fix the name
    # before enabling the option below.
    # model = rexnetv1.ReXNetV1()
    # model.output[1] = nn.Conv2d(1280, 2, kernel_size=1, stride=1)
    # model.load_state_dict(torch.load("./best0.pt"))
    # model.to(device)

    test(model, test_loader, device)
def acc_function(correct, total):
    """Express ``correct`` out of ``total`` as a percentage.

    Raises ZeroDivisionError when ``total`` is 0.
    """
    ratio = correct / total
    return ratio * 100
def test(model, data_loader, device, save_path="./confunsion_matrix.png"):
    """Evaluate ``model`` on ``data_loader``, print accuracy, save a confusion matrix.

    Args:
        model: classifier whose output is scored per class along dim 1
            (presumably shape ``(batch, num_classes)`` — TODO confirm for
            non-torchvision models).
        data_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        save_path: destination of the heatmap image. The default keeps the
            original (misspelled) file name so existing consumers still find it.

    Returns:
        Accuracy in percent, as computed by ``acc_function``.
    """
    model.eval()
    correct = 0
    total = 0
    y_pred, y_true = [], []
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            _, preds = torch.max(output, 1)
            total += images.size(0)
            correct += (labels == preds).sum().item()
            # Move tensors to host memory so numpy/sklearn can consume them.
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())
    acc = acc_function(correct, total)
    print(f"acc >> {acc}%" )

    # Build confusion matrix.
    # Assumes the dataset encodes 'bird' as 0 and 'drone' as 1 — TODO confirm
    # against CustomDataset.
    classes = ('bird', 'drone')
    # Pin the label set so the matrix is always len(classes) x len(classes),
    # even when a class never appears in y_true or y_pred; otherwise the
    # DataFrame construction below fails on a shape mismatch.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    df_cm = pd.DataFrame(cf_matrix, index=list(classes), columns=list(classes))
    fig = plt.figure(figsize=(12, 7))
    try:
        # fmt="d" prints integer counts instead of the default ".2g"
        # (which shows e.g. 123 as "1.2e+02").
        sn.heatmap(df_cm, annot=True, fmt="d")
        fig.savefig(save_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return acc
if __name__ == "__main__":
    # Only run the evaluation when executed as a script, not on import.
    test_main()