-
Notifications
You must be signed in to change notification settings - Fork 2
/
models.py
284 lines (247 loc) · 9.78 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import torch
import numpy as np
from pathlib import Path
import torch.nn as nn
def crnn(inputdim=64, outputdim=527, pretrained_from='balanced.pth', **kwargs):
model = CRNN(inputdim, outputdim, **kwargs)
if pretrained_from:
state = torch.load(pretrained_from,
map_location='cpu')
model.load_state_dict(state, strict=False)
return model
def cnn10(inputdim=64, outputdim=527, pretrained_from='balanced.pth'):
model = CNN10(inputdim, outputdim)
if pretrained_from:
state = torch.load(pretrained_from,
map_location='cpu')
model.load_state_dict(state, strict=False)
return model
def init_weights(m):
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class LinearSoftPool(nn.Module):
"""LinearSoftPool
Linear softmax, takes logits and returns a probability, near to the actual maximum value.
Taken from the paper:
A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
https://arxiv.org/abs/1810.09050
"""
def __init__(self, pooldim=1):
super().__init__()
self.pooldim = pooldim
def forward(self, logits, time_decision):
return (time_decision**2).sum(self.pooldim) / time_decision.sum(
self.pooldim)
class MeanPool(nn.Module):
def __init__(self, pooldim=1):
super().__init__()
self.pooldim = pooldim
def forward(self, logits, decision):
return torch.mean(decision, dim=self.pooldim)
def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
"""parse_poolingfunction
A heler function to parse any temporal pooling
Pooling is done on dimension 1
:param poolingfunction_name:
:param **kwargs:
"""
poolingfunction_name = poolingfunction_name.lower()
if poolingfunction_name == 'mean':
return MeanPool(pooldim=1)
elif poolingfunction_name == 'linear':
return LinearSoftPool(pooldim=1)
elif poolingfunction_name == 'attention':
return AttentionPool(inputdim=kwargs['inputdim'],
outputdim=kwargs['outputdim'])
class AttentionPool(nn.Module):
"""docstring for AttentionPool"""
def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
super().__init__()
self.inputdim = inputdim
self.outputdim = outputdim
self.pooldim = pooldim
self.transform = nn.Linear(inputdim, outputdim)
self.activ = nn.Softmax(dim=self.pooldim)
self.eps = 1e-7
def forward(self, logits, decision):
# Input is (B, T, D)
# B, T , D
w = self.activ(self.transform(logits))
detect = (decision * w).sum(
self.pooldim) / (w.sum(self.pooldim) + self.eps)
# B, T, D
return detect
class Block2D(nn.Module):
def __init__(self, cin, cout, kernel_size=3, padding=1):
super().__init__()
self.block = nn.Sequential(
nn.BatchNorm2d(cin),
nn.Conv2d(cin,
cout,
kernel_size=kernel_size,
padding=padding,
bias=False),
nn.LeakyReLU(inplace=True, negative_slope=0.1))
def forward(self, x):
return self.block(x)
class CRNN(nn.Module):
def __init__(self, inputdim, outputdim, **kwargs):
super().__init__()
self.features = nn.Sequential(
Block2D(1, 32),
nn.LPPool2d(4, (2, 4)),
Block2D(32, 128),
Block2D(128, 128),
nn.LPPool2d(4, (2, 4)),
Block2D(128, 128),
Block2D(128, 128),
nn.LPPool2d(4, (1, 4)),
nn.Dropout(0.3),
)
with torch.no_grad():
rnn_input_dim = self.features(torch.randn(1, 1, 500,
inputdim)).shape
rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
self.gru = nn.GRU(rnn_input_dim,
128,
bidirectional=kwargs.get(
'gru_bidirection', True),
batch_first=True)
self.temp_pool = parse_poolingfunction(kwargs.get(
'temppool', 'linear'),
inputdim=256 if kwargs.get(
'gru_bidirection', True) else 128,
outputdim=outputdim)
self.outputlayer = nn.Linear(256 if kwargs.get(
'gru_bidirection', True) else 128,
outputdim)
self.features.apply(init_weights)
self.outputlayer.apply(init_weights)
def forward(self, x, upsample=True):
batch, time, dim = x.shape
x = x.unsqueeze(1)
x = self.features(x)
x = x.transpose(1, 2).contiguous().flatten(-2)
x, _ = self.gru(x)
decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
if upsample:
decision_time = torch.nn.functional.interpolate(
decision_time.transpose(1, 2),
time,
mode='linear',
align_corners=False).transpose(1, 2)
decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1)
return decision, decision_time
def forward_stream_vad(self, x, h, upsample=True):
batch, time, dim = x.shape
x = x.unsqueeze(1)
x = self.features(x)
x = x.transpose(1, 2).contiguous().flatten(-2)
x, h = self.gru(x, h)
vad_post = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
if upsample:
vad_post = torch.nn.functional.interpolate(
vad_post.transpose(1, 2),
time,
mode='linear',
align_corners=False).transpose(1, 2)
return vad_post, h
class CNN10(nn.Module):
def __init__(self, inputdim, outputdim, **kwargs):
super().__init__()
self.features = nn.Sequential(
Block2D(1, 64),
Block2D(64, 64),
nn.LPPool2d(4, (2, 4)),
Block2D(64, 128),
Block2D(128, 128),
nn.LPPool2d(4, (2, 2)),
Block2D(128, 256),
Block2D(256, 256),
nn.LPPool2d(4, (1, 2)),
Block2D(256, 512),
Block2D(512, 512),
nn.LPPool2d(4, (1, 2)),
nn.Dropout(0.3),
nn.AdaptiveAvgPool2d((None, 1)),
)
self.temp_pool = parse_poolingfunction(kwargs.get(
'temppool', 'attention'),
inputdim=512,
outputdim=outputdim)
self.outputlayer = nn.Linear(512, outputdim)
self.features.apply(init_weights)
self.outputlayer.apply(init_weights)
def forward(self, x, upsample=True):
batch, time, dim = x.shape
x = x.unsqueeze(1)
x = self.features(x)
x = x.transpose(1, 2).contiguous().flatten(-2)
decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1)
if upsample:
decision_time = torch.nn.functional.interpolate(
decision_time.transpose(1, 2),
time,
mode='linear',
align_corners=False).transpose(1, 2)
return decision, decision_time
class CRNN10(nn.Module):
def __init__(self, inputdim, outputdim, **kwargs):
super().__init__()
self._hiddim = kwargs.get('hiddim', 256)
self.features = nn.Sequential(
Block2D(1, 64),
Block2D(64, 64),
nn.LPPool2d(4, (2, 4)),
Block2D(64, 128),
Block2D(128, 128),
nn.LPPool2d(4, (2, 2)),
Block2D(128, 256),
Block2D(256, 256),
nn.LPPool2d(4, (1, 2)),
Block2D(256, 512),
Block2D(512, 512),
nn.LPPool2d(4, (1, 2)),
nn.Dropout(0.3),
nn.AdaptiveAvgPool2d((None, 1)),
)
with torch.no_grad():
rnn_input_dim = self.features(torch.randn(1, 1, 500,
inputdim)).shape
rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
self.gru = nn.GRU(rnn_input_dim,
self._hiddim,
bidirectional=True,
batch_first=True)
self.temp_pool = parse_poolingfunction(kwargs.get(
'temppool', 'linear'),
inputdim=self._hiddim*2,
outputdim=outputdim)
self.outputlayer = nn.Linear(self._hiddim*2, outputdim)
self.features.apply(init_weights)
self.outputlayer.apply(init_weights)
def forward(self, x, upsample=True):
batch, time, dim = x.shape
x = x.unsqueeze(1)
x = self.features(x)
x = x.transpose(1, 2).contiguous().flatten(-2)
decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1)
if upsample:
decision_time = torch.nn.functional.interpolate(
decision_time.transpose(1, 2),
time,
mode='linear',
align_corners=False).transpose(1, 2)
return decision, decision_time