Update to MXNet 1.2 (#388)
This PR updates Sockeye to MXNet 1.2, which was released on May 21st, 2018.

The core change to Sockeye is the use of the new LayerNormalization operator, which reduces GPU memory usage. The operator uses the same set of parameters, so existing models remain compatible, but running Sockeye now requires MXNet 1.2.
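For illustration, here is a minimal sketch of the fused MXNet 1.2 `LayerNorm` symbol that Sockeye's `LayerNormalization` wrapper now emits; the variable names, toy shapes, and forward pass below are illustrative and not taken from this PR:

```python
import mxnet as mx
import numpy as np

# Build the kind of symbol the updated sockeye.layers.LayerNormalization now emits:
# a single fused LayerNorm op instead of hand-rolled mean/variance symbols.
data = mx.sym.Variable('data')
gamma = mx.sym.Variable('norm_gamma')  # scale parameter; shape inferred as (num_hidden,)
beta = mx.sym.Variable('norm_beta')    # shift parameter; shape inferred as (num_hidden,)
norm = mx.sym.LayerNorm(data=data, gamma=gamma, beta=beta, axis=-1,
                        eps=1e-06, output_mean_var=False, name='norm')

# Toy forward pass on CPU: batch of 2, hidden size 4 (illustrative only).
exe = norm.simple_bind(ctx=mx.cpu(), data=(2, 4))
exe.arg_dict['norm_gamma'][:] = 1.0
exe.arg_dict['norm_beta'][:] = 0.0
exe.forward(is_train=False, data=mx.nd.array(np.random.rand(2, 4)))
print(exe.outputs[0].asnumpy())  # each row normalized to ~zero mean and unit variance
```

The previous wrapper composed the same computation from `broadcast_minus`/`rsqrt`/`broadcast_mul` symbols (see the layers.py diff below); delegating to the single fused operator is what saves GPU memory.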
fhieber authored May 24, 2018
1 parent af59303 commit 8835331
Showing 17 changed files with 75 additions and 134 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.17]
### Changed
- Updated to MXNet 1.2
- Use of the new LayerNormalization operator to save GPU memory.

## [1.18.16]
### Fixed
- Removed summation of gradient arrays when logging gradients.
2 changes: 1 addition & 1 deletion README.md
@@ -45,7 +45,7 @@ Recent developments and changes are tracked in our [changelog](https://github.co

Sockeye requires:
- **Python3**
- [MXNet-1.1.0](https://github.com/apache/incubator-mxnet/tree/1.1.0)
- [MXNet-1.2.0](https://github.com/apache/incubator-mxnet/tree/1.2.0)
- numpy

## Installation
2 changes: 1 addition & 1 deletion requirements.gpu-cu75.txt
@@ -1,4 +1,4 @@
pyyaml
mxnet-cu75==1.1.0
mxnet-cu75==1.2.0
numpy>=1.12
typing
2 changes: 1 addition & 1 deletion requirements.gpu-cu80.txt
@@ -1,4 +1,4 @@
pyyaml
mxnet-cu80==1.1.0
mxnet-cu80==1.2.0
numpy>=1.12
typing
2 changes: 1 addition & 1 deletion requirements.gpu-cu90.txt
@@ -1,4 +1,4 @@
pyyaml
mxnet-cu90==1.1.0
mxnet-cu90==1.2.0
numpy>=1.12
typing
2 changes: 1 addition & 1 deletion requirements.gpu-cu91.txt
@@ -1,4 +1,4 @@
pyyaml
mxnet-cu91==1.1.0
mxnet-cu91==1.2.0
numpy>=1.12
typing
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
pyyaml
mxnet==1.1.0
mxnet==1.2.0
numpy>=1.12
typing
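All requirements files now pin MXNet 1.2.0. A quick, purely illustrative way to confirm that an environment matches the new pin before running Sockeye (this is not the version check Sockeye itself performs):

```python
import mxnet as mx

# Sockeye 1.18.17 expects MXNet 1.2; older 1.1.0 installs must be upgraded.
assert mx.__version__.startswith("1.2"), (
    "Found MXNet %s; please upgrade, e.g. via the pinned requirements files." % mx.__version__
)
```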
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.16'
__version__ = '1.18.17'
5 changes: 2 additions & 3 deletions sockeye/coverage.py
@@ -227,8 +227,7 @@ def __init__(self,
# optional layer normalization
self.layer_norm = None
if layer_normalization and not self.num_hidden != 1:
self.layer_norm = layers.LayerNormalization(self.num_hidden,
prefix="%snorm" % self.prefix) if layer_normalization else None
self.layer_norm = layers.LayerNormalization(prefix="%snorm" % self.prefix)

def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable:
"""
@@ -293,7 +292,7 @@ def update_coverage(prev_hidden: mx.sym.Symbol,
updated_coverage = intermediate + attention_hidden + coverage_hidden

if self.layer_norm is not None:
updated_coverage = self.layer_norm.normalize(updated_coverage)
updated_coverage = self.layer_norm(data=updated_coverage)

# (batch_size, seq_len, coverage_num_hidden)
coverage = mx.sym.Activation(data=updated_coverage,
17 changes: 8 additions & 9 deletions sockeye/decoder.py
@@ -536,9 +536,9 @@ def __init__(self,
# Hidden state parameters
self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
self.hidden_norm = layers.LayerNormalization(self.num_hidden,
prefix="%shidden_norm" % prefix) \
if self.config.layer_normalization else None
self.hidden_norm = None
if self.config.layer_normalization:
self.hidden_norm = layers.LayerNormalization(prefix="%shidden_norm" % prefix)

def _create_state_init_parameters(self):
"""
@@ -553,9 +552,8 @@ def _create_state_init_parameters(self):
self.init_ws.append(mx.sym.Variable("%senc2decinit_%d_weight" % (self.prefix, state_idx)))
self.init_bs.append(mx.sym.Variable("%senc2decinit_%d_bias" % (self.prefix, state_idx)))
if self.config.layer_normalization:
self.init_norms.append(layers.LayerNormalization(num_hidden=init_num_hidden,
prefix="%senc2decinit_%d_norm" % (
self.prefix, state_idx)))
self.init_norms.append(layers.LayerNormalization(prefix="%senc2decinit_%d_norm" % (self.prefix,
state_idx)))

def decode_sequence(self,
source_encoded: mx.sym.Symbol,
@@ -796,7 +795,7 @@ def get_initial_state(self,
bias=self.init_bs[state_idx],
name="%senc2decinit_%d" % (self.prefix, state_idx))
if self.config.layer_normalization:
init = self.init_norms[state_idx].normalize(init)
init = self.init_norms[state_idx](data=init)
init = mx.sym.Activation(data=init, act_type="tanh",
name="%senc2dec_inittanh_%d" % (self.prefix, state_idx))
if self.config.state_init_lhuc:
@@ -870,7 +869,7 @@ def _hidden_mlp(self, hidden_concat: mx.sym.Symbol, seq_idx: int) -> mx.sym.Symb
bias=self.hidden_b,
name='%shidden_fc_t%d' % (self.prefix, seq_idx))
if self.config.layer_normalization:
hidden = self.hidden_norm.normalize(hidden)
hidden = self.hidden_norm(data=hidden)

# hidden: (batch_size, rnn_num_hidden)
hidden = mx.sym.Activation(data=hidden, act_type="tanh",
@@ -904,7 +903,7 @@ def _context_gate(self,
hidden = gate * mapped_rnn_output + (1 - gate) * mapped_context

if self.config.layer_normalization:
hidden = self.hidden_norm.normalize(hidden)
hidden = self.hidden_norm(data=hidden)

# hidden: (batch_size, rnn_num_hidden)
hidden = mx.sym.Activation(data=hidden, act_type="tanh",
42 changes: 9 additions & 33 deletions sockeye/layers.py
@@ -54,62 +54,38 @@ class LayerNormalization:
"""
Implements Ba et al, Layer Normalization (https://arxiv.org/abs/1607.06450).
:param num_hidden: Number of hidden units of layer to be normalized.
:param prefix: Optional prefix of layer name.
:param scale: Optional variable for scaling of shape (num_hidden,). Will be created if None.
:param shift: Optional variable for shifting of shape (num_hidden,). Will be created if None.
:param scale_init: Initial value of scale variable if scale is None. Default 1.0.
:param shift_init: Initial value of shift variable if shift is None. Default 0.0.
"""

# TODO(fhieber): this should eventually go to MXNet

def __init__(self,
num_hidden: int,
prefix: Optional[str] = None,
prefix: str = 'layernorm',
scale: Optional[mx.sym.Symbol] = None,
shift: Optional[mx.sym.Symbol] = None,
scale_init: float = 1.0,
shift_init: float = 0.0) -> None:
utils.check_condition(num_hidden > 1,
"Layer normalization should only be applied to layers with more than 1 neuron.")
self.prefix = prefix
self.scale = scale if scale is not None else mx.sym.Variable('%s_gamma' % prefix, shape=(num_hidden,),
self.scale = scale if scale is not None else mx.sym.Variable('%s_gamma' % prefix,
init=mx.init.Constant(value=scale_init))
self.shift = shift if shift is not None else mx.sym.Variable('%s_beta' % prefix, shape=(num_hidden,),
self.shift = shift if shift is not None else mx.sym.Variable('%s_beta' % prefix,
init=mx.init.Constant(value=shift_init))

@staticmethod
def moments(inputs: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]:
"""
Computes mean and variance of the last dimension of a Symbol.
:param inputs: Shape: (d0, ..., dn, hidden).
:return: mean, var: Shape: (d0, ..., dn, 1).
"""
mean = mx.sym.mean(data=inputs, axis=-1, keepdims=True)
# TODO(fhieber): MXNet should have this.
var = mx.sym.mean(mx.sym.square(mx.sym.broadcast_minus(inputs, mean)), axis=-1, keepdims=True)
return mean, var

def normalize(self, inputs: mx.sym.Symbol, eps: float = 0.000001) -> mx.sym.Symbol:
def __call__(self, data: mx.sym.Symbol, eps: float = 1e-06) -> mx.sym.Symbol:
"""
Normalizes hidden units of inputs as follows:
Normalizes hidden units of data as follows:
inputs = scale * (inputs - mean) / sqrt(var + eps) + shift
data = scale * (data - mean) / sqrt(var + eps) + shift
Normalization is performed over the last dimension of the input data.
:param inputs: Inputs to normalize. Shape: (d0, ..., dn, num_hidden).
:param data: Data to normalize. Shape: (d0, ..., dn, num_hidden).
:param eps: Variance epsilon.
:return: inputs_norm: Normalized inputs. Shape: (d0, ..., dn, num_hidden).
"""
mean, var = self.moments(inputs)
inputs_norm = mx.sym.broadcast_minus(inputs, mean, name='%sinp_minus_mean' % self.prefix)
inputs_norm = mx.sym.broadcast_mul(inputs_norm, mx.sym.rsqrt(var + eps), name='%sinp_norm' % self.prefix)
inputs_norm = mx.sym.broadcast_mul(inputs_norm, self.scale, name='%sinp_norm_scaled' % self.prefix)
inputs_norm = mx.sym.broadcast_add(inputs_norm, self.shift, name='%sinp_norm_scaled_shifted' % self.prefix)
return inputs_norm
return mx.sym.LayerNorm(data=data, gamma=self.scale, beta=self.shift, axis=-1,
eps=eps, output_mean_var=False, name=self.prefix)


class LHUC:
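As the updated call sites in coverage.py and decoder.py above show, the wrapper no longer takes `num_hidden` and is invoked directly instead of through `.normalize()`. A small usage sketch of the new interface follows; the `dec_norm` prefix and the `hidden` symbol are illustrative:

```python
import mxnet as mx
from sockeye import layers

hidden = mx.sym.Variable("hidden")

# Old interface (pre-1.18.17): layers.LayerNormalization(num_hidden, prefix=...).normalize(hidden)
# New interface: no num_hidden argument; the instance is called like an operator.
layer_norm = layers.LayerNormalization(prefix="dec_norm")
normalized = layer_norm(data=hidden)  # builds mx.sym.LayerNorm with the wrapper's gamma/beta variables
```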
