Skip to content

Commit

Permalink
Create some_notes_on_testing_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chwendytsai authored Nov 19, 2022
1 parent 4664679 commit 8576c65
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions some_notes_on_testing_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Nov 19 17:13:41 2022
@author: wendytsai
"""
#from https://github.com/musikalkemist/pytorchforaudio/tree/main/03%20Making%20predictions


import torch
from train import model, download_mnist_datasets


class_mapping = [
"0",
"1",
"2",
]


def predict(model, input, target, class_mapping):
model.eval()
with torch.no_grad():
predictions = model(input)
# Tensor (1, 10) -> [ [0.1, 0.01, ..., 0.6] ]
predicted_index = predictions[0].argmax(0)
predicted = class_mapping[predicted_index]
expected = class_mapping[target]
return predicted, expected


if __name__ == "__main__":
# load back the model
model = model()
state_dict = torch.load("model.pth")
model.load_state_dict(state_dict)

# load MNIST validation dataset
_, validation_data = download_mnist_datasets()

# get a sample from the validation dataset for inference
input, target = validation_data[0][0], validation_data[0][1]

# make an inference
predicted, expected = predict(model, input, target,
class_mapping)
print(f"Predicted: '{predicted}', expected: '{expected}'")



# from https://learn.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-train-model

import matplotlib.pyplot as plt
import numpy as np

# Function to show the images
def imageshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()


# Function to test the model with a batch of images and show the labels predictions
def testBatch():
# get batch of images from the test DataLoader
images, labels = next(iter(test_loader))

# show all images as one image grid
imageshow(torchvision.utils.make_grid(images))

# Let's see what if the model identifiers the labels of those example
outputs = model(images)

# We got the probability for every 10 labels. The highest (max) probability should be correct label
_, predicted = torch.max(outputs, 1)

# Let's show the predicted labels on the screen to compare with the real ones
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(batch_size)))

0 comments on commit 8576c65

Please sign in to comment.