# cam_helper.py
# Helpers for computing a Class Activation Map (CAM) and overlaying it on an image.
import torch
import torch.nn.functional as F
import numpy as np
import cv2
def get_cam(model, img_tensor, target_layer):
    """
    Generate a Class Activation Map (CAM) for a given image and model.

    The feature maps emitted by ``target_layer`` are summed channel by
    channel, each weighted by the matching entry of the model's
    second-to-last parameter tensor (presumably the weight of the final
    linear classifier, with its bias being the last parameter -- confirm
    for your architecture). The result is rectified, resized to the spatial
    size of ``img_tensor`` and min-max normalised.

    Args:
        model (nn.Module): The neural network model. It is switched to
            eval mode as a side effect.
        img_tensor (torch.Tensor): The input image tensor. Its last two
            dimensions are taken as (height, width). The code squeezes the
            captured activations, so a single-image batch is assumed.
        target_layer (str): Name of the layer to hook, as listed by
            ``model.named_modules()``.

    Returns:
        np.ndarray: The CAM mask, float32, shaped (height, width), with
        values in [0, 1]. A completely flat map is returned as all zeros.

    Raises:
        ValueError: If ``target_layer`` is not found in the model, or if the
            classifier weights cannot be matched to the captured activations.
    """
    model.eval()

    layer = dict(model.named_modules()).get(target_layer)
    if layer is None:
        raise ValueError(f"Layer {target_layer} not found in the model")

    captured = {}

    def forward_hook(module, inputs, output):
        captured["features"] = output

    hook = layer.register_forward_hook(forward_hook)
    try:
        with torch.no_grad():
            output = model(img_tensor)
    finally:
        # Always detach the hook, even when the forward pass fails, so the
        # caller's model is never left with a stale hook attached.
        hook.remove()

    # .cpu() keeps this working when the model lives on an accelerator.
    weights = np.squeeze(list(model.parameters())[-2].detach().cpu().numpy())
    features = captured["features"].squeeze().detach().cpu().numpy()

    if weights.ndim == 2:
        # Multi-class head: one weight row per class. Use the row of the
        # highest-scoring class, as in the standard CAM formulation.
        scores = output[0].detach().reshape(-1)
        if scores.numel() != weights.shape[0]:
            raise ValueError(
                f"Cannot pick a class: model produced {scores.numel()} scores "
                f"but the classifier weights have {weights.shape[0]} rows"
            )
        weights = weights[int(scores.argmax())]

    if (
        weights.ndim != 1
        or features.ndim != 3
        or weights.shape[0] != features.shape[0]
    ):
        raise ValueError(
            f"Cannot combine weights of shape {weights.shape} with "
            f"activations of shape {features.shape} from layer {target_layer}"
        )

    # Weighted sum over the channel axis: (C,) . (C, H, W) -> (H, W).
    cam = np.tensordot(weights, features, axes=1).astype(np.float32, copy=False)
    cam = np.maximum(cam, 0)

    # cv2.resize expects the target size as (width, height).
    cam = cv2.resize(cam, (img_tensor.shape[-1], img_tensor.shape[-2]))

    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        # Skip the division for a flat map; it would otherwise produce NaNs.
        cam = cam / peak
    return cam
def apply_cam_on_image(img, cam):
    """
    Overlay a CAM mask on an image using an inverted colormap.

    The mask is resized to the image, inverted (``1 - cam``), coloured with
    OpenCV's TWILIGHT_SHIFTED colormap, added to the image and rescaled so
    that the brightest value of the sum maps to 255.

    Args:
        img (np.ndarray): The original image. It is divided by 255 before
            blending, so 0-255 pixel values are presumably expected.
        cam (np.ndarray): The CAM mask. It is scaled by 255 after inversion,
            so values in [0, 1] are presumably expected.

    Returns:
        np.ndarray: The blended image as uint8.
    """
    height, width = img.shape[:2]

    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(cam, (width, height))

    # Invert the mask before colouring it.
    inverted = np.uint8(255 * (1 - mask))
    colored = cv2.applyColorMap(inverted, cv2.COLORMAP_TWILIGHT_SHIFTED)

    # Bring both layers to the 0-1 range, add them, then renormalise by the peak.
    blended = np.float32(colored) / 255 + np.float32(img) / 255
    blended = blended / np.max(blended)
    return np.uint8(255 * blended)