-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create some_notes_on_testing_model.py
- Loading branch information
1 parent
4664679
commit 8576c65
Showing
1 changed file
with
81 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Sat Nov 19 17:13:41 2022 | ||
@author: wendytsai | ||
""" | ||
#from https://github.com/musikalkemist/pytorchforaudio/tree/main/03%20Making%20predictions | ||
|
||
|
||
import torch | ||
from train import model, download_mnist_datasets | ||
|
||
|
||
class_mapping = [ | ||
"0", | ||
"1", | ||
"2", | ||
] | ||
|
||
|
||
def predict(model, input, target, class_mapping): | ||
model.eval() | ||
with torch.no_grad(): | ||
predictions = model(input) | ||
# Tensor (1, 10) -> [ [0.1, 0.01, ..., 0.6] ] | ||
predicted_index = predictions[0].argmax(0) | ||
predicted = class_mapping[predicted_index] | ||
expected = class_mapping[target] | ||
return predicted, expected | ||
|
||
|
||
if __name__ == "__main__": | ||
# load back the model | ||
model = model() | ||
state_dict = torch.load("model.pth") | ||
model.load_state_dict(state_dict) | ||
|
||
# load MNIST validation dataset | ||
_, validation_data = download_mnist_datasets() | ||
|
||
# get a sample from the validation dataset for inference | ||
input, target = validation_data[0][0], validation_data[0][1] | ||
|
||
# make an inference | ||
predicted, expected = predict(model, input, target, | ||
class_mapping) | ||
print(f"Predicted: '{predicted}', expected: '{expected}'") | ||
|
||
|
||
|
||
# from https://learn.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-train-model | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
# Function to show the images | ||
def imageshow(img): | ||
img = img / 2 + 0.5 # unnormalize | ||
npimg = img.numpy() | ||
plt.imshow(np.transpose(npimg, (1, 2, 0))) | ||
plt.show() | ||
|
||
|
||
# Function to test the model with a batch of images and show the labels predictions | ||
def testBatch(): | ||
# get batch of images from the test DataLoader | ||
images, labels = next(iter(test_loader)) | ||
|
||
# show all images as one image grid | ||
imageshow(torchvision.utils.make_grid(images)) | ||
|
||
# Let's see what if the model identifiers the labels of those example | ||
outputs = model(images) | ||
|
||
# We got the probability for every 10 labels. The highest (max) probability should be correct label | ||
_, predicted = torch.max(outputs, 1) | ||
|
||
# Let's show the predicted labels on the screen to compare with the real ones | ||
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] | ||
for j in range(batch_size))) |