-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
29 lines (26 loc) · 831 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from __future__ import print_function
from __future__ import division
import torch
from torchvision import transforms
from PIL import Image
from efficientnet_pytorch import EfficientNet
# Smoke test: load a pretrained EfficientNet-B7, classify one local image and
# print the predicted class index. Runs on the first CUDA device when one is
# available, otherwise on the CPU.
file = "./test.png"

# Build the B7 architecture and load its pretrained weights.
model = EfficientNet.from_pretrained("efficientnet-b7")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference mode: disables dropout and makes batch-norm use running statistics.
model.eval()

# Resize the shorter side to 224, center-crop to 224x224, convert to a float
# tensor, then normalize with the standard ImageNet channel mean/std.
# NOTE(review): EfficientNet-B7 is usually run at a much larger input size
# (~600px); 224 is fine for a smoke test but may not reflect the model's best
# accuracy. Confirm whether a larger size is wanted here.
trans = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Open the image in a context manager so the file handle is released once the
# pixels have been converted. convert("RGB") forces exactly three channels
# (e.g. drops a PNG alpha channel), matching the 3-value Normalize above.
with Image.open(file) as raw_img:
    img = trans(raw_img.convert("RGB"))
img = img.unsqueeze(0)  # add a batch dimension: (1, C, H, W)
img = img.to(device)

with torch.no_grad():  # no autograd bookkeeping needed for inference
    output = model(img)
_, preds = torch.max(output, 1)  # index of the highest score per sample
print(f"Test passed, class number of test image is {preds.item()}")