-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
130 lines (99 loc) · 4.31 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 15 17:32:03 2020
@author: krishna
"""
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import torch
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 1-D convolution with kernel size 3 and padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        The configured ``nn.Conv1d`` layer.
    """
    # With kernel_size=3 and padding=1 the temporal length is preserved
    # whenever stride == 1.
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv1d(in_planes, out_planes, **conv_kwargs)
class BasicBlock3x3(nn.Module):
    """Residual block made of two kernel-size-3 1-D convolutions.

    Each convolution is followed by batch normalisation. A LeakyReLU is
    applied after the first batch norm and once more after the skip
    connection has been added to the output of the second batch norm.

    Args:
        inplanes3: Number of channels of the block input.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional module applied to the input to produce the skip
            connection; when ``None`` the input itself is used.
    """

    # Ratio of output channels to ``planes``.
    expansion = 1

    def __init__(self, inplanes3, planes, stride=1, downsample=None):
        super().__init__()
        # Main path: conv -> bn -> activation -> conv -> bn.
        self.conv1 = conv3x3(inplanes3, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm1d(planes)
        # Skip path and bookkeeping.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``activation(main_path(x) + skip(x))``."""
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        # In-place accumulation, as the activation itself is in-place.
        out += shortcut
        return self.relu(out)
class RawNet(nn.Module):
    """1-D residual CNN followed by a GRU, an embedding layer and a classifier.

    Pipeline (see ``forward``): a strided convolution, two 128-channel
    residual stages and four 256-channel residual stages (each followed by a
    max-pool of stride 3), a single-layer unidirectional GRU, a linear
    projection to a 128-dimensional embedding and a linear output layer.

    Args:
        input_channel: Number of channels of the input signal.
        num_classes: Size of the output layer (defaults to 1211).
    """

    def __init__(self, input_channel, num_classes=1211):
        # Initialise nn.Module first so that every later attribute assignment
        # goes through a fully set-up module.
        super().__init__()
        # Channel count fed to the next residual stage; _make_layer3 updates it.
        self.inplanes3 = 128

        self.conv1 = nn.Conv1d(input_channel, 128, kernel_size=3, stride=3,
                               padding=0, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True)

        # Residual stage 1: 128 channels.
        self.resblock_1_1 = self._make_layer3(BasicBlock3x3, 128, 1, stride=1)
        self.resblock_1_2 = self._make_layer3(BasicBlock3x3, 128, 1, stride=1)
        self.maxpool_resblock_1 = nn.MaxPool1d(kernel_size=3, stride=3, padding=0)

        # Residual stage 2: 256 channels (the first block gets a 1x1
        # projection on its skip path because the channel count changes).
        self.resblock_2_1 = self._make_layer3(BasicBlock3x3, 256, 1, stride=1)
        self.resblock_2_2 = self._make_layer3(BasicBlock3x3, 256, 1, stride=1)
        self.resblock_2_3 = self._make_layer3(BasicBlock3x3, 256, 1, stride=1)
        self.resblock_2_4 = self._make_layer3(BasicBlock3x3, 256, 1, stride=1)
        self.maxpool_resblock_2 = nn.MaxPool1d(kernel_size=3, stride=3, padding=0)

        # NOTE: the GRU used to be built with dropout=0.2. With the default
        # num_layers=1 PyTorch applies no dropout at all (it only acts between
        # stacked layers) and emits a UserWarning, so the argument is omitted.
        # The model's outputs are unchanged. No dropout is applied anywhere in
        # this model.
        self.gru = nn.GRU(input_size=256, hidden_size=1024,
                          bidirectional=False, batch_first=True)
        self.spk_emb = nn.Linear(1024, 128)
        self.output_layer = nn.Linear(128, num_classes)

    def _make_layer3(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks producing ``planes * block.expansion`` channels.

        The first block receives ``stride`` and, when the stride or the channel
        count changes, a 1x1 convolution + batch norm on its skip path. The
        remaining blocks keep stride 1. ``self.inplanes3`` is updated to the
        new channel count as a side effect.

        Returns:
            An ``nn.Sequential`` containing the stacked blocks.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes3 != out_planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes3, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_planes),
            )

        layers = [block(self.inplanes3, planes, stride, downsample)]
        self.inplanes3 = out_planes
        layers.extend(block(self.inplanes3, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the network on a batch of signals.

        Args:
            inputs: Tensor of shape ``(batch, input_channel, length)``. The
                stem convolution and the six max-pool steps each divide the
                length by 3 (rounding down), so ``length`` must be at least
                ``3 ** 7 = 2187`` for a non-empty sequence to reach the GRU.

        Returns:
            A tuple ``(preds, spk_embeddings)`` where ``preds`` has shape
            ``(batch, num_classes)`` and ``spk_embeddings`` has shape
            ``(batch, 128)``.
        """
        out = self.conv1(inputs)
        out = self.bn1(out)
        out = self.relu(out)

        # Residual stage 1 (the pooling module is stateless, so it is reused).
        out = self.resblock_1_1(out)
        out = self.maxpool_resblock_1(out)
        out = self.resblock_1_2(out)
        out = self.maxpool_resblock_1(out)

        # Residual stage 2.
        out = self.resblock_2_1(out)
        out = self.maxpool_resblock_2(out)
        out = self.resblock_2_2(out)
        out = self.maxpool_resblock_2(out)
        out = self.resblock_2_3(out)
        out = self.maxpool_resblock_2(out)
        out = self.resblock_2_4(out)
        out = self.maxpool_resblock_2(out)

        # GRU expects (batch, time, features) because batch_first=True.
        out = out.permute(0, 2, 1)
        out, _ = self.gru(out)

        # Use the GRU output at the last time step as the utterance summary.
        spk_embeddings = self.spk_emb(out[:, -1, :])
        preds = self.output_layer(spk_embeddings)
        return preds, spk_embeddings