diff --git a/CHANGELOG.md b/CHANGELOG.md
index d8ee2abe0..308767422 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.18.17]
+### Changed
+- Updated to MXNet 1.2
+- Use of the new LayerNormalization operator to save GPU memory.
+
 ## [1.18.16]
 ### Fixed
 - Removed summation of gradient arrays when logging gradients.
diff --git a/README.md b/README.md
index 97dc8605b..aeebf01c5 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,7 @@ Recent developments and changes are tracked in our [changelog](https://github.co
 Sockeye requires:
 - **Python3**
-- [MXNet-1.1.0](https://github.com/apache/incubator-mxnet/tree/1.1.0)
+- [MXNet-1.2.0](https://github.com/apache/incubator-mxnet/tree/1.2.0)
 - numpy
 
 ## Installation
diff --git a/requirements.gpu-cu75.txt b/requirements.gpu-cu75.txt
index 8cc9377fa..c571fa279 100644
--- a/requirements.gpu-cu75.txt
+++ b/requirements.gpu-cu75.txt
@@ -1,4 +1,4 @@
 pyyaml
-mxnet-cu75==1.1.0
+mxnet-cu75==1.2.0
 numpy>=1.12
 typing
diff --git a/requirements.gpu-cu80.txt b/requirements.gpu-cu80.txt
index 8d793a743..19a211ba9 100644
--- a/requirements.gpu-cu80.txt
+++ b/requirements.gpu-cu80.txt
@@ -1,4 +1,4 @@
 pyyaml
-mxnet-cu80==1.1.0
+mxnet-cu80==1.2.0
 numpy>=1.12
 typing
diff --git a/requirements.gpu-cu90.txt b/requirements.gpu-cu90.txt
index d1c6003f0..4a5d2475a 100644
--- a/requirements.gpu-cu90.txt
+++ b/requirements.gpu-cu90.txt
@@ -1,4 +1,4 @@
 pyyaml
-mxnet-cu90==1.1.0
+mxnet-cu90==1.2.0
 numpy>=1.12
 typing
diff --git a/requirements.gpu-cu91.txt b/requirements.gpu-cu91.txt
index b366668eb..f448ff66e 100644
--- a/requirements.gpu-cu91.txt
+++ b/requirements.gpu-cu91.txt
@@ -1,4 +1,4 @@
 pyyaml
-mxnet-cu91==1.1.0
+mxnet-cu91==1.2.0
 numpy>=1.12
 typing
diff --git a/requirements.txt b/requirements.txt
index 205bb3851..e433be1a4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 pyyaml
-mxnet==1.1.0
+mxnet==1.2.0
 numpy>=1.12
 typing
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index 3d1008b20..9c9184dd7 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.16'
+__version__ = '1.18.17'
diff --git a/sockeye/coverage.py b/sockeye/coverage.py
index 5fefd7e10..5b0925cee 100644
--- a/sockeye/coverage.py
+++ b/sockeye/coverage.py
@@ -227,8 +227,7 @@ def __init__(self,
         # optional layer normalization
         self.layer_norm = None
         if layer_normalization and not self.num_hidden != 1:
-            self.layer_norm = layers.LayerNormalization(self.num_hidden,
-                                                        prefix="%snorm" % self.prefix) if layer_normalization else None
+            self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix)
 
     def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable:
         """
@@ -293,7 +292,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
             updated_coverage = intermediate + attention_hidden + coverage_hidden
 
             if self.layer_norm is not None:
-                updated_coverage = self.layer_norm.normalize(updated_coverage)
+                updated_coverage = self.layer_norm(data=updated_coverage)
 
             # (batch_size, seq_len, coverage_num_hidden)
             coverage = mx.sym.Activation(data=updated_coverage,
diff --git a/sockeye/decoder.py b/sockeye/decoder.py
index 95a2e006f..6bcab01e7 100644
--- a/sockeye/decoder.py
+++ b/sockeye/decoder.py
@@ -536,9 +536,9 @@ def __init__(self,
         # Hidden state parameters
         self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
         self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
-        self.hidden_norm = layers.LayerNormalization(self.num_hidden,
-                                                     prefix="%shidden_norm" % prefix) \
-            if self.config.layer_normalization else None
+        self.hidden_norm = None
+        if self.config.layer_normalization:
+            self.hidden_norm = layers.LayerNormalization(prefix="%shidden_norm" % prefix)
 
     def _create_state_init_parameters(self):
         """
@@ -553,9 +553,8 @@ def _create_state_init_parameters(self):
             self.init_ws.append(mx.sym.Variable("%senc2decinit_%d_weight" % (self.prefix, state_idx)))
             self.init_bs.append(mx.sym.Variable("%senc2decinit_%d_bias" % (self.prefix, state_idx)))
             if self.config.layer_normalization:
-                self.init_norms.append(layers.LayerNormalization(num_hidden=init_num_hidden,
-                                                                 prefix="%senc2decinit_%d_norm" % (
-                                                                     self.prefix, state_idx)))
+                self.init_norms.append(layers.LayerNormalization(prefix="%senc2decinit_%d_norm" % (self.prefix,
+                                                                                                   state_idx)))
 
     def decode_sequence(self,
                         source_encoded: mx.sym.Symbol,
@@ -796,7 +795,7 @@ def get_initial_state(self,
                                       bias=self.init_bs[state_idx],
                                       name="%senc2decinit_%d" % (self.prefix, state_idx))
             if self.config.layer_normalization:
-                init = self.init_norms[state_idx].normalize(init)
+                init = self.init_norms[state_idx](data=init)
             init = mx.sym.Activation(data=init, act_type="tanh",
                                      name="%senc2dec_inittanh_%d" % (self.prefix, state_idx))
             if self.config.state_init_lhuc:
@@ -870,7 +869,7 @@ def _hidden_mlp(self, hidden_concat: mx.sym.Symbol, seq_idx: int) -> mx.sym.Symb
                                       bias=self.hidden_b,
                                       name='%shidden_fc_t%d' % (self.prefix, seq_idx))
         if self.config.layer_normalization:
-            hidden = self.hidden_norm.normalize(hidden)
+            hidden = self.hidden_norm(data=hidden)
 
         # hidden: (batch_size, rnn_num_hidden)
         hidden = mx.sym.Activation(data=hidden, act_type="tanh",
@@ -904,7 +903,7 @@ def _context_gate(self,
         hidden = gate * mapped_rnn_output + (1 - gate) * mapped_context
 
         if self.config.layer_normalization:
-            hidden = self.hidden_norm.normalize(hidden)
+            hidden = self.hidden_norm(data=hidden)
 
         # hidden: (batch_size, rnn_num_hidden)
         hidden = mx.sym.Activation(data=hidden, act_type="tanh",
diff --git a/sockeye/layers.py b/sockeye/layers.py
index 46df810b5..8e13d3888 100644
--- a/sockeye/layers.py
+++ b/sockeye/layers.py
@@ -54,62 +54,38 @@ class LayerNormalization:
     """
    Implements Ba et al, Layer Normalization (https://arxiv.org/abs/1607.06450).
 
-    :param num_hidden: Number of hidden units of layer to be normalized.
     :param prefix: Optional prefix of layer name.
     :param scale: Optional variable for scaling of shape (num_hidden,). Will be created if None.
     :param shift: Optional variable for shifting of shape (num_hidden,). Will be created if None.
     :param scale_init: Initial value of scale variable if scale is None. Default 1.0.
     :param shift_init: Initial value of shift variable if shift is None. Default 0.0.
     """
-
-    # TODO(fhieber): this should eventually go to MXNet
-    def __init__(self,
-                 num_hidden: int,
-                 prefix: Optional[str] = None,
+    def __init__(self,
+                 prefix: str = 'layernorm',
                  scale: Optional[mx.sym.Symbol] = None,
                  shift: Optional[mx.sym.Symbol] = None,
                  scale_init: float = 1.0,
                  shift_init: float = 0.0) -> None:
-        utils.check_condition(num_hidden > 1,
-                              "Layer normalization should only be applied to layers with more than 1 neuron.")
         self.prefix = prefix
-        self.scale = scale if scale is not None else mx.sym.Variable('%s_gamma' % prefix, shape=(num_hidden,),
+        self.scale = scale if scale is not None else mx.sym.Variable('%s_gamma' % prefix,
                                                                      init=mx.init.Constant(value=scale_init))
-        self.shift = shift if shift is not None else mx.sym.Variable('%s_beta' % prefix, shape=(num_hidden,),
+        self.shift = shift if shift is not None else mx.sym.Variable('%s_beta' % prefix,
                                                                      init=mx.init.Constant(value=shift_init))
 
-    @staticmethod
-    def moments(inputs: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]:
-        """
-        Computes mean and variance of the last dimension of a Symbol.
-
-        :param inputs: Shape: (d0, ..., dn, hidden).
-        :return: mean, var: Shape: (d0, ..., dn, 1).
-        """
-        mean = mx.sym.mean(data=inputs, axis=-1, keepdims=True)
-        # TODO(fhieber): MXNet should have this.
-        var = mx.sym.mean(mx.sym.square(mx.sym.broadcast_minus(inputs, mean)), axis=-1, keepdims=True)
-        return mean, var
-
-    def normalize(self, inputs: mx.sym.Symbol, eps: float = 0.000001) -> mx.sym.Symbol:
+    def __call__(self, data: mx.sym.Symbol, eps: float = 1e-06) -> mx.sym.Symbol:
         """
-        Normalizes hidden units of inputs as follows:
+        Normalizes hidden units of data as follows:
 
-        inputs = scale * (inputs - mean) / sqrt(var + eps) + shift
+        data = scale * (data - mean) / sqrt(var + eps) + shift
 
         Normalization is performed over the last dimension of the input data.
 
-        :param inputs: Inputs to normalize. Shape: (d0, ..., dn, num_hidden).
+        :param data: Data to normalize. Shape: (d0, ..., dn, num_hidden).
         :param eps: Variance epsilon.
         :return: inputs_norm: Normalized inputs. Shape: (d0, ..., dn, num_hidden).
""" - mean, var = self.moments(inputs) - inputs_norm = mx.sym.broadcast_minus(inputs, mean, name='%sinp_minus_mean' % self.prefix) - inputs_norm = mx.sym.broadcast_mul(inputs_norm, mx.sym.rsqrt(var + eps), name='%sinp_norm' % self.prefix) - inputs_norm = mx.sym.broadcast_mul(inputs_norm, self.scale, name='%sinp_norm_scaled' % self.prefix) - inputs_norm = mx.sym.broadcast_add(inputs_norm, self.shift, name='%sinp_norm_scaled_shifted' % self.prefix) - return inputs_norm + return mx.sym.LayerNorm(data=data, gamma=self.scale, beta=self.shift, axis=-1, + eps=eps, output_mean_var=False, name=self.prefix) class LHUC: diff --git a/sockeye/rnn.py b/sockeye/rnn.py index 846ad7054..372ac487d 100644 --- a/sockeye/rnn.py +++ b/sockeye/rnn.py @@ -195,25 +195,15 @@ def __init__(self, norm_scale: float = 1.0, norm_shift: float = 0.0) -> None: super(LayerNormLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias) - self._iN = LayerNormalization(num_hidden=num_hidden * 4, - prefix="%si2h" % self._prefix, - scale=self.params.get('i2h_scale', shape=(num_hidden * 4,), - init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('i2h_shift', shape=(num_hidden * 4,), - init=mx.init.Constant(value=norm_shift))) - self._hN = LayerNormalization(num_hidden=num_hidden * 4, - prefix="%sh2h" % self._prefix, - scale=self.params.get('h2h_scale', shape=(num_hidden * 4,), - init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('h2h_shift', shape=(num_hidden * 4,), - init=mx.init.Constant(value=norm_shift))) - self._cN = LayerNormalization(num_hidden=num_hidden, - prefix="%sc" % self._prefix, - scale=self.params.get('c_scale', shape=(num_hidden,), - init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('c_shift', shape=(num_hidden,), - init=mx.init.Constant(value=norm_shift))) - self._shape_fix = None + self._iN = LayerNormalization(prefix="%si2h" % self._prefix, + scale=self.params.get('i2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('i2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift))) + self._hN = LayerNormalization(prefix="%sh2h" % self._prefix, + scale=self.params.get('h2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('h2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift))) + self._cN = LayerNormalization(prefix="%sc" % self._prefix, + scale=self.params.get('c_scale', shape=(num_hidden,), init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('c_shift', shape=(num_hidden,), init=mx.init.Constant(value=norm_shift))) def __call__(self, inputs, states): self._counter += 1 @@ -221,14 +211,10 @@ def __call__(self, inputs, states): i2h = mx.sym.FullyConnected(data=inputs, weight=self._iW, bias=self._iB, num_hidden=self._num_hidden * 4, name='%si2h' % name) - if self._counter == 0: - self._shape_fix = mx.sym.zeros_like(i2h) - else: - assert self._shape_fix is not None h2h = mx.sym.FullyConnected(data=states[0], weight=self._hW, bias=self._hB, num_hidden=self._num_hidden * 4, name='%sh2h' % name) - gates = self._iN.normalize(i2h) + self._hN.normalize(self._shape_fix + h2h) + gates = self._iN(data=i2h) + self._hN(data=h2h + mx.sym.zeros_like(i2h)) # pylint: disable=unbalanced-tuple-unpacking in_gate, forget_gate, in_transform, out_gate = mx.sym.split(gates, num_outputs=4, @@ -245,8 +231,7 @@ def __call__(self, inputs, states): next_c = mx.sym._internal._plus(forget_gate * states[1], in_gate * in_transform, name='%sstate' % 
name) next_h = mx.sym._internal._mul(out_gate, - mx.sym.Activation(self._cN.normalize(next_c), - act_type="tanh"), + mx.sym.Activation(self._cN(data=next_c), act_type="tanh"), name='%sout' % name) return next_h, [next_h, next_c] @@ -274,12 +259,12 @@ def __init__(self, super(LayerNormPerGateLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias) self._norm_layers = list() # type: List[LayerNormalization] for name in ['i', 'f', 'c', 'o', 's']: - scale = self.params.get('%s_shift' % name, shape=(num_hidden,), + scale = self.params.get('%s_shift' % name, init=mx.init.Constant(value=norm_shift)) - shift = self.params.get('%s_scale' % name, shape=(num_hidden,), + shift = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale if name != "f" else forget_bias)) self._norm_layers.append( - LayerNormalization(num_hidden, prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift)) + LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift)) def __call__(self, inputs, states): self._counter += 1 @@ -295,10 +280,10 @@ def __call__(self, inputs, states): in_gate, forget_gate, in_transform, out_gate = mx.sym.split( gates, num_outputs=4, name="%sslice" % name) - in_gate = self._norm_layers[0].normalize(in_gate) - forget_gate = self._norm_layers[1].normalize(forget_gate) - in_transform = self._norm_layers[2].normalize(in_transform) - out_gate = self._norm_layers[3].normalize(out_gate) + in_gate = self._norm_layers[0](data=in_gate) + forget_gate = self._norm_layers[1](data=forget_gate) + in_transform = self._norm_layers[2](data=in_transform) + out_gate = self._norm_layers[3](data=out_gate) in_gate = mx.sym.Activation(in_gate, act_type="sigmoid", name='%si' % name) @@ -311,7 +296,7 @@ def __call__(self, inputs, states): next_c = mx.sym._internal._plus(forget_gate * states[1], in_gate * in_transform, name='%sstate' % name) next_h = mx.sym._internal._mul(out_gate, - mx.sym.Activation(self._norm_layers[4].normalize(next_c), act_type="tanh"), + mx.sym.Activation(self._norm_layers[4].__call__(next_c), act_type="tanh"), name='%sout' % name) return next_h, [next_h, next_c] @@ -392,19 +377,12 @@ def __init__(self, norm_scale: float = 1.0, norm_shift: float = 0.0) -> None: super(LayerNormGRUCell, self).__init__(num_hidden, prefix, params) - self._iN = LayerNormalization(num_hidden=num_hidden * 3, - prefix="%si2h" % self._prefix, - scale=self.params.get('i2h_scale', shape=(num_hidden * 3,), - init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('i2h_shift', shape=(num_hidden * 3,), - init=mx.init.Constant(value=norm_shift))) - self._hN = LayerNormalization(num_hidden=num_hidden * 3, - prefix="%sh2h" % self._prefix, - scale=self.params.get('h2h_scale', shape=(num_hidden * 3,), - init=mx.init.Constant(value=norm_scale)), - shift=self.params.get('h2h_shift', shape=(num_hidden * 3,), - init=mx.init.Constant(value=norm_shift))) - self._shape_fix = None + self._iN = LayerNormalization(prefix="%si2h" % self._prefix, + scale=self.params.get('i2h_scale', init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('i2h_shift', init=mx.init.Constant(value=norm_shift))) + self._hN = LayerNormalization(prefix="%sh2h" % self._prefix, + scale=self.params.get('h2h_scale', init=mx.init.Constant(value=norm_scale)), + shift=self.params.get('h2h_shift', init=mx.init.Constant(value=norm_shift))) def __call__(self, inputs, states): self._counter += 1 @@ -423,13 +401,9 @@ def __call__(self, inputs, states): bias=self._hB, num_hidden=self._num_hidden * 3, 
name="%s_h2h" % name) - if self._counter == 0: - self._shape_fix = mx.sym.zeros_like(i2h) - else: - assert self._shape_fix is not None - i2h = self._iN.normalize(i2h) - h2h = self._hN.normalize(self._shape_fix + h2h) + i2h = self._iN(data=i2h) + h2h = self._hN(data=h2h) # pylint: disable=unbalanced-tuple-unpacking i2h_r, i2h_z, i2h = mx.sym.split(i2h, num_outputs=3, name="%s_i2h_slice" % name) @@ -470,10 +444,9 @@ def __init__(self, super(LayerNormPerGateGRUCell, self).__init__(num_hidden, prefix, params) self._norm_layers = list() # type: List[LayerNormalization] for name in ['r', 'z', 'o']: - scale = self.params.get('%s_shift' % name, shape=(num_hidden,), init=mx.init.Constant(value=norm_shift)) - shift = self.params.get('%s_scale' % name, shape=(num_hidden,), init=mx.init.Constant(value=norm_scale)) - self._norm_layers.append( - LayerNormalization(num_hidden, prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift)) + scale = self.params.get('%s_shift' % name, init=mx.init.Constant(value=norm_shift)) + shift = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale)) + self._norm_layers.append(LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift)) def __call__(self, inputs, states): self._counter += 1 @@ -497,12 +470,12 @@ def __call__(self, inputs, states): i2h_r, i2h_z, i2h = mx.sym.split(i2h, num_outputs=3, name="%s_i2h_slice" % name) h2h_r, h2h_z, h2h = mx.sym.split(h2h, num_outputs=3, name="%s_h2h_slice" % name) - reset_gate = mx.sym.Activation(self._norm_layers[0].normalize(i2h_r + h2h_r), + reset_gate = mx.sym.Activation(self._norm_layers[0](data=i2h_r + h2h_r), act_type="sigmoid", name="%s_r_act" % name) - update_gate = mx.sym.Activation(self._norm_layers[1].normalize(i2h_z + h2h_z), + update_gate = mx.sym.Activation(self._norm_layers[1](data=i2h_z + h2h_z), act_type="sigmoid", name="%s_z_act" % name) - next_h_tmp = mx.sym.Activation(self._norm_layers[2].normalize(i2h + reset_gate * h2h), + next_h_tmp = mx.sym.Activation(self._norm_layers[2](data=i2h + reset_gate * h2h), act_type="tanh", name="%s_h_act" % name) next_h = mx.sym._internal._plus((1. 
- update_gate) * next_h_tmp, update_gate * prev_state_h, diff --git a/sockeye/rnn_attention.py b/sockeye/rnn_attention.py index 0415cfa46..88c6e6f20 100644 --- a/sockeye/rnn_attention.py +++ b/sockeye/rnn_attention.py @@ -643,8 +643,9 @@ def __init__(self, # input (coverage) to hidden self.att_c2h_weight = None # layer normalization - self._ln = layers.LayerNormalization(num_hidden=num_hidden, - prefix="%snorm" % self.prefix) if layer_normalization else None + self._ln = None + if layer_normalization: + self._ln = layers.LayerNormalization(prefix="%snorm" % self.prefix) def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable: """ @@ -708,7 +709,7 @@ def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionSta name="%squery_plus_input" % self.prefix) if self._ln is not None: - attention_hidden = self._ln.normalize(attention_hidden) + attention_hidden = self._ln(data=attention_hidden) # (batch_size, seq_len, attention_num_hidden) attention_hidden = mx.sym.Activation(attention_hidden, act_type="tanh", diff --git a/sockeye/transformer.py b/sockeye/transformer.py index 148757abf..d9e3b9f97 100644 --- a/sockeye/transformer.py +++ b/sockeye/transformer.py @@ -208,7 +208,7 @@ def __init__(self, self.prefix = prefix self.layer_norm = None if "n" in sequence: - self.layer_norm = layers.LayerNormalization(num_hidden=self.num_hidden, prefix="%snorm" % self.prefix) + self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix) def __call__(self, data: mx.sym.Symbol, @@ -232,7 +232,7 @@ def __call__(self, data = mx.sym._internal._plus(data, prev, name="%sresidual" % self.prefix) elif step == "n": - data = self.layer_norm.normalize(data) + data = self.layer_norm(data=data) elif step == "d": if self.dropout > 0.0: diff --git a/test/unit/test_coverage.py b/test/unit/test_coverage.py index e939ae833..3168a5d5b 100644 --- a/test/unit/test_coverage.py +++ b/test/unit/test_coverage.py @@ -81,11 +81,10 @@ def _test_activation_coverage(act_type): executor.arg_dict["prev_coverage"][:] = prev_coverage_data executor.arg_dict["attention_scores"][:] = attention_scores_data result = executor.forward() - # this is needed to modulate the 0 input. The output changes according to the activation type used. - activation = mx.sym.Activation(name="activation", act_type=act_type) - modulated = activation.eval(ctx=mx.cpu(), activation_data=mx.nd.zeros((1, 1)))[0].asnumpy() new_coverage = result[0].asnumpy() assert new_coverage.shape == prev_coverage_shape + # this is needed to modulate the 0 input. The output changes according to the activation type used. + modulated = mx.nd.Activation(mx.nd.zeros((1, 1)), act_type=act_type).asnumpy() assert (np.sum(np.sum(new_coverage == modulated, axis=2) != 0, axis=1) == source_length_data).all() @@ -143,5 +142,4 @@ def _patch_sequence_mask(test): with patch.object(mx, 'sym', wraps=mx.sym) as mxnet_mock: # Patch Sequence Mask to use ones for padding. 
         mxnet_mock.SequenceMask = _mask_with_one
-
         test()
diff --git a/test/unit/test_layers.py b/test/unit/test_layers.py
index 32561aee6..c7d640d29 100644
--- a/test/unit/test_layers.py
+++ b/test/unit/test_layers.py
@@ -26,25 +26,16 @@ def test_layer_normalization():
     x_nd = mx.nd.uniform(0, 10, (batch_size, other_dim, num_hidden))
     x_np = x_nd.asnumpy()
 
-    ln = sockeye.layers.LayerNormalization(num_hidden, prefix="")
-
-    # test moments
-    sym = mx.sym.Group(ln.moments(x))
-    mean, var = sym.eval(x=x_nd)
+    ln = sockeye.layers.LayerNormalization(prefix="")
 
     expected_mean = np.mean(x_np, axis=-1, keepdims=True)
     expected_var = np.var(x_np, axis=-1, keepdims=True)
+    expected_norm = (x_np - expected_mean) / np.sqrt(expected_var)
 
-    assert np.isclose(mean.asnumpy(), expected_mean).all()
-    assert np.isclose(var.asnumpy(), expected_var).all()
-
-    sym = ln.normalize(x)
-    norm = sym.eval(x=x_nd,
+    norm = ln(x).eval(x=x_nd,
                     _gamma=mx.nd.ones((num_hidden,)),
                     _beta=mx.nd.zeros((num_hidden,)))[0]
 
-    expected_norm = (x_np - expected_mean) / np.sqrt(expected_var)
-
     assert np.isclose(norm.asnumpy(), expected_norm, atol=1.e-6).all()
diff --git a/test/unit/test_rnn.py b/test/unit/test_rnn.py
index f39103b19..b9efc0406 100644
--- a/test/unit/test_rnn.py
+++ b/test/unit/test_rnn.py
@@ -47,7 +47,6 @@ def test_ln_cell(cell, expected_param_keys):
     inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
-    print(sorted(cell.params._params.keys()))
     assert sorted(cell.params._params.keys()) == expected_param_keys
     assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
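
Reviewer note (not part of the patch): the substance of this change is that `sockeye.layers.LayerNormalization` no longer computes mean and variance by hand in `moments()`/`normalize()` but delegates to the fused `LayerNorm` operator added in MXNet 1.2, which is why every call site above drops the `num_hidden` argument and invokes the wrapper directly. Below is a minimal sketch, assuming only `mxnet>=1.2.0` and `numpy` are installed (the variable names are illustrative and not part of the diff), that checks the NDArray counterpart of the new operator against the arithmetic the removed symbolic code implemented:

```python
# Sketch only: compare the fused LayerNorm operator (new code path) with the
# explicit mean/variance normalization that the removed moments()/normalize()
# methods built out of broadcast_minus/broadcast_mul.
import mxnet as mx
import numpy as np

batch_size, seq_len, num_hidden = 2, 3, 8
x = mx.nd.uniform(0, 10, (batch_size, seq_len, num_hidden))
gamma = mx.nd.ones((num_hidden,))   # scale parameter ('%s_gamma')
beta = mx.nd.zeros((num_hidden,))   # shift parameter ('%s_beta')

# Fused operator introduced in MXNet 1.2; sockeye/layers.py uses its symbolic
# twin mx.sym.LayerNorm with the same arguments.
fused = mx.nd.LayerNorm(data=x, gamma=gamma, beta=beta, axis=-1, eps=1e-06)

# Reference computation over the last axis, matching the removed code:
# (x - mean) / sqrt(var + eps), then scale and shift (here 1 and 0).
x_np = x.asnumpy()
mean = np.mean(x_np, axis=-1, keepdims=True)
var = np.var(x_np, axis=-1, keepdims=True)
reference = (x_np - mean) / np.sqrt(var + 1e-06)

assert np.allclose(fused.asnumpy(), reference, atol=1e-5)
```

The GPU memory saving mentioned in the changelog plausibly comes from this consolidation: a single fused kernel replaces the chain of broadcast operations, and their intermediate arrays, that the old `normalize()` emitted.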