-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVisualization.py
64 lines (56 loc) · 2.43 KB
/
Visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from lime import lime_image
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
class LIMEExplainer:
    """Explain an image classifier's predictions with LIME.

    Wraps a ``lime_image.LimeImageExplainer`` around a PyTorch model and
    offers a numeric "explainability score" as well as a matplotlib
    visualisation of the superpixels LIME highlights for the top label.
    """

    def __init__(self, model, device, data_loader):
        """Store the collaborators and create the LIME explainer.

        Args:
            model: Model called on an image batch; its output is soft-maxed
                over dimension 1 in ``batch_predict``.
            device: Device the image batch is moved to before the forward
                pass.
            data_loader: Stored on the instance; none of the methods in this
                class read it.
        """
        self.model = model
        self.device = device
        self.data_loader = data_loader
        self.explainer = lime_image.LimeImageExplainer()

    def batch_predict(self, images):
        """Return class probabilities for a batch of images (LIME callback).

        Args:
            images: Iterable of images accepted by ``TF.to_pil_image``.
                Presumably the HxWxC arrays LIME generates as perturbed
                copies of the explained image -- confirm against the caller.

        Returns:
            NumPy array holding the softmax (over dimension 1) of the model
            output for the batch.

        Note:
            Switches the model to eval mode and leaves it there.
        """
        # Each image goes through PIL so that ``to_tensor`` produces the
        # tensor layout the model is fed with.
        # NOTE(review): no normalisation is applied here -- confirm the model
        # expects un-normalised inputs.
        batch = torch.stack(
            [TF.to_tensor(TF.to_pil_image(img)) for img in images]
        ).to(self.device)
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(batch)
        return torch.softmax(outputs, dim=1).cpu().numpy()

    @staticmethod
    def _to_numpy_image(image):
        """Convert a tensor or PIL image to the NumPy array LIME expects.

        NumPy arrays are returned unchanged. Tensors are detached, moved to
        the CPU and converted through PIL, exactly as the visualisation path
        has always done; anything else (e.g. a PIL image) goes through
        ``np.array``.
        """
        if isinstance(image, np.ndarray):
            return image
        if isinstance(image, torch.Tensor):
            # NOTE(review): to_pil_image assumes float tensors are in [0, 1];
            # a normalised tensor would need de-normalising first -- confirm.
            image = TF.to_pil_image(image.detach().cpu())
        return np.array(image)

    def _explain_top_label(self, image, num_samples, top_labels, num_features=5):
        """Run LIME on ``image`` and return ``(top_label, temp, mask)``.

        ``temp`` and ``mask`` are the outputs of ``get_image_and_mask`` for
        the highest-ranked label, restricted to positively contributing
        superpixels.
        """
        explanation = self.explainer.explain_instance(
            self._to_numpy_image(image),
            self.batch_predict,
            top_labels=top_labels,
            # Superpixels switched off in a perturbed sample are painted
            # with this colour.
            hide_color=0,
            num_samples=num_samples,
        )
        top_label = explanation.top_labels[0]
        temp, mask = explanation.get_image_and_mask(
            top_label,
            positive_only=True,
            num_features=num_features,
            # Keep the rest of the image visible instead of blanking it.
            hide_rest=False,
        )
        return top_label, temp, mask

    def get_explainability_score(self, image, label, num_samples=1000):
        """Score how much of the image LIME marks as supporting the top label.

        Args:
            image: Image to explain: an HxWxC NumPy array (passed to LIME
                unchanged), a tensor, or a PIL image.
            label: Unused; kept so existing callers keep working. The score
                is always computed for the model's top predicted label.
            num_samples: Number of perturbed samples LIME generates.

        Returns:
            Tuple ``(score, top_label)`` where ``score`` is the mean of the
            LIME mask for the top label (the covered fraction of the image,
            assuming LIME's 0/1 mask) and ``top_label`` is that label.
        """
        top_label, _, mask = self._explain_top_label(
            image, num_samples, top_labels=5
        )
        return np.mean(mask), top_label

    def visualize_explanation(self, image, label, num_samples=1000):
        """Plot the image with the LIME-selected superpixels outlined.

        Args:
            image: Image to explain: a tensor (moved to the CPU and converted
                through PIL), an HxWxC NumPy array, or a PIL image.
            label: Unused; kept so existing callers keep working. The plot
                always shows the model's top predicted label.
            num_samples: Number of perturbed samples LIME generates.
        """
        _, temp, mask = self._explain_top_label(
            image, num_samples, top_labels=1
        )
        plt.figure(figsize=(6, 6))
        plt.imshow(mark_boundaries(temp, mask))
        plt.title('Highlighted Features by LIME')
        plt.axis('off')
        plt.show()