From 8576c6546445d064b193cb7d77a35a519daedc47 Mon Sep 17 00:00:00 2001 From: CH Wendy Tsai <116213073+chwendytsai@users.noreply.github.com> Date: Sat, 19 Nov 2022 18:05:46 +0000 Subject: [PATCH] Create some_notes_on_testing_model.py --- some_notes_on_testing_model.py | 81 ++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 some_notes_on_testing_model.py diff --git a/some_notes_on_testing_model.py b/some_notes_on_testing_model.py new file mode 100644 index 0000000..0bd5e0c --- /dev/null +++ b/some_notes_on_testing_model.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sat Nov 19 17:13:41 2022 + +@author: wendytsai +""" +#from https://github.com/musikalkemist/pytorchforaudio/tree/main/03%20Making%20predictions + + +import torch +from train import model, download_mnist_datasets + + +class_mapping = [ + "0", + "1", + "2", +] + + +def predict(model, input, target, class_mapping): + model.eval() + with torch.no_grad(): + predictions = model(input) + # Tensor (1, 10) -> [ [0.1, 0.01, ..., 0.6] ] + predicted_index = predictions[0].argmax(0) + predicted = class_mapping[predicted_index] + expected = class_mapping[target] + return predicted, expected + + +if __name__ == "__main__": + # load back the model + model = model() + state_dict = torch.load("model.pth") + model.load_state_dict(state_dict) + + # load MNIST validation dataset + _, validation_data = download_mnist_datasets() + + # get a sample from the validation dataset for inference + input, target = validation_data[0][0], validation_data[0][1] + + # make an inference + predicted, expected = predict(model, input, target, + class_mapping) + print(f"Predicted: '{predicted}', expected: '{expected}'") + + + +# from https://learn.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-train-model + +import matplotlib.pyplot as plt +import numpy as np + +# Function to show the images +def imageshow(img): + img = img / 2 + 0.5 # unnormalize + npimg = img.numpy() + plt.imshow(np.transpose(npimg, (1, 2, 0))) + plt.show() + + +# Function to test the model with a batch of images and show the labels predictions +def testBatch(): + # get batch of images from the test DataLoader + images, labels = next(iter(test_loader)) + + # show all images as one image grid + imageshow(torchvision.utils.make_grid(images)) + + # Let's see what if the model identifiers the labels of those example + outputs = model(images) + + # We got the probability for every 10 labels. The highest (max) probability should be correct label + _, predicted = torch.max(outputs, 1) + + # Let's show the predicted labels on the screen to compare with the real ones + print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] + for j in range(batch_size)))