# -*- coding: utf-8 -*-
"""model_ELU

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/16WWdQBk1WY8YYwFctpDzYEXNOW7cbzly
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ELU activation."""

    def __init__(self, inchannel, outchannel, kernelsize, stride=1, pad=1, use_act=True):
        super(ConvBlock, self).__init__()
        self.use_act = use_act
        self.conv = nn.Conv2d(in_channels=inchannel, out_channels=outchannel,
                              kernel_size=kernelsize, stride=stride, padding=pad)
        self.bn = nn.BatchNorm2d(outchannel)  # normalizes over the conv's output channels
        self.act = nn.ELU()

    def forward(self, x):
        op = self.bn(self.conv(x))
        return self.act(op) if self.use_act else op
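
# Illustrative usage (a minimal sketch; the channel counts and input size below
# are arbitrary assumptions, not values from the original training setup):
#
#   block = ConvBlock(inchannel=3, outchannel=64, kernelsize=3)
#   y = block(torch.randn(1, 3, 32, 32))
#   # kernel 3 with padding 1 preserves the spatial size: y.shape == (1, 64, 32, 32)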

class ResBlock(nn.Module):
    """Residual block: two ConvBlocks plus a skip connection. The skip lets
    gradients bypass the convolutions, improving gradient flow. The elementwise
    add in forward() requires inchannel == outchannel."""

    def __init__(self, inchannel, outchannel, kernelsize):
        super(ResBlock, self).__init__()
        self.block1 = ConvBlock(inchannel, outchannel, kernelsize)
        # The second block consumes the first block's output channels and skips
        # the activation, so the residual sum itself is what gets returned.
        self.block2 = ConvBlock(outchannel, outchannel, kernelsize, use_act=False)

    def forward(self, x):
        return x + self.block2(self.block1(x))

class SRResnet(nn.Module):
    """SRResNet-style backbone: an input conv, a stack of residual blocks, a
    post-residual conv joined to the input features by a long skip connection,
    and an output conv mapping back to the input channel count."""

    def __init__(self, inchannel, outchannel, res_layers=16):
        super(SRResnet, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1)
        self.act = nn.ELU()
        resl = [ResBlock(outchannel, outchannel, 3) for _ in range(res_layers)]
        self.resl = nn.Sequential(*resl)
        self.conv2 = ConvBlock(outchannel, outchannel, 3, use_act=False)
        self.conv3 = nn.Conv2d(outchannel, inchannel, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        op1 = self.act(self.conv1(x))
        op2 = self.conv2(self.resl(op1))
        # Long skip connection from the first activation to after the residual stack.
        op = self.conv3(torch.add(op1, op2))
        return op
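
# A minimal smoke test (a sketch under assumed shapes: 3-channel images and
# 64 feature channels are illustrative choices, not taken from the original):
if __name__ == "__main__":
    model = SRResnet(inchannel=3, outchannel=64)
    x = torch.randn(1, 3, 32, 32)  # batch of one 3-channel 32x32 image
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 3, 32, 32]); spatial size is preserved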