-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
37 lines (26 loc) · 886 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from cbam import CAM, SAM, CBAM
def test_CAM(input_tensor):
    """Apply a newly constructed ``CAM`` to ``input_tensor`` and return the result.

    Args:
        input_tensor: Object exposing a ``shape`` whose entries after the
            first unpack into exactly three values (four dimensions in
            total). The first of those three is passed to ``CAM`` as
            ``in_channels``.

    Returns:
        Whatever ``CAM(in_channels=...)`` produces when called on
        ``input_tensor``.

    Raises:
        ValueError: If ``input_tensor.shape[1:]`` does not contain exactly
            three entries.
    """
    # Only the first entry is used; unpacking all three keeps the implicit
    # four-dimension check on the input.
    channels, _height, _width = input_tensor.shape[1:]
    return CAM(in_channels=channels)(input_tensor)
def test_SAM(input_tensor):
    """Apply a newly constructed ``SAM`` to ``input_tensor`` and return the result.

    Args:
        input_tensor: Object exposing a ``shape`` whose entries after the
            first unpack into exactly three values (four dimensions in
            total). ``SAM`` is built with no arguments, so none of the
            unpacked values are used.

    Returns:
        Whatever ``SAM()`` produces when called on ``input_tensor``.

    Raises:
        ValueError: If ``input_tensor.shape[1:]`` does not contain exactly
            three entries.
    """
    # None of these values feed into SAM; the unpack only preserves the
    # implicit four-dimension check on the input.
    _channels, _height, _width = input_tensor.shape[1:]
    return SAM()(input_tensor)
def test_CBAM(input_tensor):
    """Apply a newly constructed ``CBAM`` to ``input_tensor`` and return the result.

    Args:
        input_tensor: Object exposing a ``shape`` whose entries after the
            first unpack into exactly three values (four dimensions in
            total). The first of those three is passed to ``CBAM`` as
            ``in_channels``.

    Returns:
        Whatever ``CBAM(in_channels=...)`` produces when called on
        ``input_tensor``.

    Raises:
        ValueError: If ``input_tensor.shape[1:]`` does not contain exactly
            three entries.
    """
    # Only the first entry is used; unpacking all three keeps the implicit
    # four-dimension check on the input.
    channels, _height, _width = input_tensor.shape[1:]
    return CBAM(in_channels=channels)(input_tensor)
def main():
    """Run the CAM, SAM and CBAM helpers on one dummy tensor and print shapes.

    Builds an all-ones tensor of shape ``[8, 64, 10, 10]``, passes the same
    tensor to ``test_CAM``, ``test_SAM`` and ``test_CBAM`` in that order, and
    prints the ``shape`` of each result. Returns ``None``.
    """
    batch_size, c, h, w = 8, 64, 10, 10
    print('Initialize tensor with shape [{},{},{},{}]'.format(batch_size, c, h, w))

    # torch.autograd.Variable is deprecated: since PyTorch 0.4 wrapping a
    # tensor in Variable just returns a tensor, so build the tensor directly.
    dummy_tensor = torch.ones(batch_size, c, h, w)

    # Same call/print sequence for each module, in the original order.
    checks = (
        ('CAM', test_CAM),
        ('SAM', test_SAM),
        ('CBAM', test_CBAM),
    )
    for label, run in checks:
        out = run(dummy_tensor)
        print('Shape after {}: '.format(label), out.shape)
# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()