Skip to content

Commit

Permalink
remove batch_first
Browse files (browse the repository at this point in the history)
  • Loading branch information
HawkAaron committed Apr 21, 2018
1 parent d37c9a0 commit cbf3ec4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, vocab_size, num_hidden, num_layers, dropout=0, blank=0, bidir
self.num_hidden = num_hidden
self.num_layers = num_layers
self.vocab_size = vocab_size
self.loss = RNNTLoss(blank)
self.loss = RNNTLoss(blank_label=blank)
self.blank = blank
with self.name_scope():
# acoustic model. NOTE: only initialize encoder.rnn; we can reuse encoder.decoder
Expand Down
2 changes: 1 addition & 1 deletion model2012.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, vocab_size, num_hidden, num_layers, dropout=0, blank=0, bidir
self.num_hidden = num_hidden
self.num_layers = num_layers
self.vocab_size = vocab_size
self.loss = RNNTLoss(blank)
self.loss = RNNTLoss(blank_label=blank)
self.blank = blank
with self.name_scope():
# acoustic model
Expand Down
18 changes: 14 additions & 4 deletions rnnt_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,22 @@
from mxnet import autograd, gluon

class RNNTLoss(gluon.loss.Loss):
def __init__(self, blank_label=0, weight=None, **kwargs):
super(RNNTLoss, self).__init__(weight, batch_axis=0, **kwargs)
def __init__(self, layout='NTC', label_layout='NT', blank_label=0, weight=None, **kwargs):
assert layout in ['NTC', 'TNC'],\
"Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s"%layout
assert label_layout in ['NT', 'TN'],\
"Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
self._layout = layout
self._label_layout = label_layout
batch_axis = label_layout.find('N')
super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
self.blank_label = blank_label

def hybrid_forward(self, F, pred, label,
pred_lengths=None, label_lengths=None):
def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths):
if self._layout == 'NTC':
pred = F.moveaxis(pred, 0, 2)
if self._batch_axis == 1:
label = F.swapaxes(label, 0, 1)
cpu = mx.cpu()
loss = F.contrib.RNNTLoss(pred.as_in_context(cpu), label.as_in_context(cpu),
pred_lengths.as_in_context(cpu), label_lengths.as_in_context(cpu),
Expand Down

0 comments on commit cbf3ec4

Please sign in to comment.