"""
backbone.py
Mar 4 2023
Gabriel Moreira
"""
import torch
from torch import nn


class Convnet(nn.Module):
    def __init__(self, in_dim=3, hid_dim=64, out_dim=64):
        """
        Four-block convolutional encoder (Conv4). Each block halves the
        spatial resolution, so an 84x84 input yields out_dim x 5 x 5 features.
        """
        super().__init__()
        self.encoder = nn.Sequential(conv_block(in_dim, hid_dim),
                                     conv_block(hid_dim, hid_dim),
                                     conv_block(hid_dim, hid_dim),
                                     conv_block(hid_dim, out_dim))
        self.pool = nn.MaxPool2d(5)

    def forward(self, x):
        """
        Encode a batch of images into flat feature vectors
        (out_dim-dimensional for 84x84 inputs).
        """
        x = self.encoder(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return x


def conv_block(in_dim: int, out_dim: int):
    """3x3 conv -> batch norm -> ReLU -> 2x2 max-pool (halves resolution)."""
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, 3, padding=1),
                         nn.BatchNorm2d(out_dim),
                         nn.ReLU(),
                         nn.MaxPool2d(2))


class ResidualMerge(nn.Module):
    """Identity skip connection around a module: x -> x + model(x)."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return x + self.model(x)


class ResNLH(nn.Module):
    def __init__(self,
                 n: int,
                 l: int,
                 h: int,
                 out_dim: int,
                 in_dim: int = 3):
        """
        Small residual network: n strided 2x2 convolutions for downsampling,
        followed by l residual blocks with h channels, then a linear head.
        """
        super().__init__()
        # n downsampling stages; each 2x2 stride-2 conv halves the resolution
        in_layers = [nn.Conv2d(in_dim, h, 2, stride=2),
                     nn.BatchNorm2d(h),
                     nn.ReLU()]
        for _ in range(n - 1):
            in_layers.extend([nn.Conv2d(h, h, 2, stride=2),
                              nn.BatchNorm2d(h),
                              nn.ReLU()])
        self.conv_block = nn.Sequential(*in_layers)

        # l residual blocks (conv-BN-ReLU-conv-BN plus skip, then ReLU)
        # at constant resolution and channel count
        residual_blocks = []
        for _ in range(l):
            residual_blocks.extend([ResidualMerge(nn.Sequential(nn.Conv2d(h, h, 3, stride=1, padding=1),
                                                                nn.BatchNorm2d(h),
                                                                nn.ReLU(),
                                                                nn.Conv2d(h, h, 3, stride=1, padding=1),
                                                                nn.BatchNorm2d(h))),
                                    nn.ReLU()])
        self.residuals = nn.Sequential(*residual_blocks)
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
        # NOTE: the head input size is hard-coded; 6400 assumes the feature
        # map entering it has h * (H / 2^n) * (W / 2^n) = 6400 elements
        # (e.g., h=64 on a 10x10 map).
        self.fc = nn.Linear(6400, out_dim)

    def forward(self, x):
        x = self.conv_block(x)
        x = self.residuals(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
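

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the original file's API.
    # It assumes 84x84 inputs for Convnet and 80x80 inputs with n=3, h=64
    # for ResNLH, so that the flattened feature size matches the hard-coded
    # 6400 (64 channels * 10 * 10 spatial map).
    conv4 = Convnet()
    print(conv4(torch.randn(2, 3, 84, 84)).shape)   # torch.Size([2, 64])

    resnlh = ResNLH(n=3, l=4, h=64, out_dim=128)
    print(resnlh(torch.randn(2, 3, 80, 80)).shape)  # torch.Size([2, 128])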