-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
68 lines (54 loc) · 1.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import math
import matplotlib.pyplot as plt
import numpy as np
from torch_rotation import rotate_three_pass
def rotate_pytorch(image, angle=0, mode='bilinear', padding_mode='zeros'):
    """Rotate a batch of images about their centre with ``grid_sample``.

    The rotation matrix is written in the normalized ``[-1, 1]`` coordinates
    used by ``torch.nn.functional.affine_grid``. The off-diagonal terms are
    rescaled by the aspect ratio so the rotation stays rigid in pixel space
    on non-square images.

    Args:
        image: Tensor of shape ``(N, C, H, W)``. Only this 4-D layout is
            supported, because a 2x3 affine matrix is built.
        angle: Rotation angle in radians. It can be a real scalar (``int``,
            ``float``, NumPy scalar), which is applied to the whole batch. It
            can also be a tensor assignable to shape ``(N,)``, giving one
            angle per sample.
        mode: Interpolation mode forwarded to ``grid_sample``.
        padding_mode: Padding mode forwarded to ``grid_sample``.

    Returns:
        Rotated tensor with the same shape and dtype as ``image``. The result
        is clamped to ``[0, 1]`` because interpolation such as bicubic can
        overshoot.
    """
    shape = image.shape
    height, width = shape[-2], shape[-1]
    # Use the image dtype: affine_grid returns a grid in theta's dtype, and
    # grid_sample requires the grid and the input to share a dtype.
    amat = torch.zeros(shape[0], 2, 3, dtype=image.dtype, device=image.device)
    if torch.is_tensor(angle):
        cos_a = torch.cos(angle)
        sin_a = torch.sin(angle)
    else:
        # Any real scalar, including the default int 0.
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)
    amat[:, 0, 0] = cos_a
    amat[:, 0, 1] = -sin_a * height / width  # (h/w)
    amat[:, 1, 0] = sin_a * width / height   # (w/h)
    amat[:, 1, 1] = cos_a
    grid = torch.nn.functional.affine_grid(
        theta=amat,
        size=shape,
        align_corners=False
    )
    image_rotated = torch.nn.functional.grid_sample(
        input=image,
        grid=grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=False
    )
    return image_rotated.clamp(0, 1)
if __name__ == '__main__':
    # Visual comparison: three-pass rotation vs. grid_sample (bicubic) rotation.
    img = plt.imread('data/cat.jpg')
    # plt.imread returns integer samples for most formats (e.g. JPEG) but
    # floats already in [0, 1] for PNG. Only integer data needs rescaling.
    if np.issubdtype(img.dtype, np.integer):
        img = img.astype(np.float32) / np.iinfo(img.dtype).max
    else:
        img = img.astype(np.float32)
    # HWC -> (1, C, H, W). Assumes a multi-channel (colour) image.
    img_t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    angle = 135 * math.pi / 180  # radians

    rot_three_pass = rotate_three_pass(img_t, angle)
    rot_grid = rotate_pytorch(img_t, angle, mode="bicubic")
    img_rot = rot_three_pass.squeeze(0).permute(1, 2, 0).numpy()
    img_rot_pt = rot_grid.squeeze(0).permute(1, 2, 0).numpy()

    err = np.abs(img_rot - img_rot_pt)
    err_max = err.max()
    # Avoid 0/0 (NaNs) when both methods agree exactly.
    err_display = err / err_max if err_max > 0 else err

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(img_rot)
    plt.title("Three pass")
    plt.subplot(1, 3, 2)
    plt.imshow(img_rot_pt)
    plt.title("Bicubic")
    plt.subplot(1, 3, 3)
    plt.imshow(err_display)
    plt.title("Difference")
    plt.show()