Skip to content

Commit

Permalink
fix batch axis & dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
HawkAaron committed Apr 15, 2019
1 parent cbf3ec4 commit 52714c8
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions rnnt_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,19 @@
from mxnet import autograd, gluon

class RNNTLoss(gluon.loss.Loss):
def __init__(self, layout='NTC', label_layout='NT', blank_label=0, weight=None, **kwargs):
assert layout in ['NTC', 'TNC'],\
"Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s"%layout
assert label_layout in ['NT', 'TN'],\
"Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
self._layout = layout
self._label_layout = label_layout
batch_axis = label_layout.find('N')
def __init__(self, batch_first=True, blank_label=0, weight=None, **kwargs):
batch_axis = 0 if batch_first else 2
super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
self.batch_first = batch_first
self.blank_label = blank_label

def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths):
if self._layout == 'NTC':
pred = F.moveaxis(pred, 0, 2)
if self._batch_axis == 1:
label = F.swapaxes(label, 0, 1)
cpu = mx.cpu()
loss = F.contrib.RNNTLoss(pred.as_in_context(cpu), label.as_in_context(cpu),
pred_lengths.as_in_context(cpu), label_lengths.as_in_context(cpu),
blank_label=self.blank_label)
if not self.batch_first:
pred = F.transpose(pred, (2, 0, 1, 3))

loss = F.contrib.RNNTLoss(pred, label.astype('int32', False),
pred_lengths.astype('int32', False),
label_lengths.astype('int32', False),
blank_label=self.blank_label)
return loss

0 comments on commit 52714c8

Please sign in to comment.