Interleaved Multi-head Attention Operators (#884)
Replaced the batched dot product in multi-head attention with the interleaved_matmul attention operators to improve performance. Also changed batch-major data to time-major format inside the model, as required by the new operators.
blchu authored Oct 2, 2020
1 parent 07a5737 commit 9014405
Showing 18 changed files with 108 additions and 208 deletions.
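In the previous implementation, attention scores were computed per head with explicit transposes and a batched dot product; the interleaved operators fuse those steps and expect time-major, head-interleaved inputs. Below is a minimal sketch of the fused self-attention path, assuming the MXNet contrib operators introduced in apache/incubator-mxnet#16408; the shapes and the `heads`/`head_dim` values are illustrative, and this is not the actual Sockeye layer code.

```python
import mxnet as mx

seq_len, batch_size, heads, head_dim = 10, 4, 8, 64

# Time-major, head-interleaved Q/K/V projections: (seq_len, batch, heads * 3 * head_dim)
qkv = mx.nd.random.normal(shape=(seq_len, batch_size, heads * 3 * head_dim))

# Attention scores for all heads in one fused call: (batch * heads, seq_len, seq_len)
scores = mx.nd.contrib.interleaved_matmul_selfatt_qk(qkv, heads=heads)
probs = mx.nd.softmax(scores, axis=-1)

# Attention-weighted values, still time-major: (seq_len, batch, heads * head_dim)
context = mx.nd.contrib.interleaved_matmul_selfatt_valatt(qkv, probs, heads=heads)
```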
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [2.2.0]

### Changed

- Replaced multi-head attention with [interleaved_matmul_encdec](https://github.com/apache/incubator-mxnet/pull/16408) operators, which removes previously needed transposes and improves performance.

- Beam search states and model layers now assume time-major format.

## [2.1.26]

12 changes: 6 additions & 6 deletions MANIFEST.in
@@ -7,12 +7,12 @@ include pylintrc
include .flake8
include typechecked-files
include test/data/config_with_missing_attributes.yaml
include test/data/model_2.1.x/config
include test/data/model_2.1.x/params.best
include test/data/model_2.1.x/model_input
include test/data/model_2.1.x/vocab*
include test/data/model_2.1.x/version
include test/data/model_2.1.x/README.md
include test/data/model_2.2.x/config
include test/data/model_2.2.x/params.best
include test/data/model_2.2.x/model_input
include test/data/model_2.2.x/vocab*
include test/data/model_2.2.x/version
include test/data/model_2.2.x/README.md
include sockeye/git_version.py
include *.bib
recursive-include .github *
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.1.26'
__version__ = '2.2.0'
11 changes: 6 additions & 5 deletions sockeye/beam_search.py
@@ -417,17 +417,17 @@ def _repeat_states(states: List, beam_size: int, state_structure: List) -> List:
assert len(states) == len(flat_structure), "Number of states do not match the defined state structure"
for state, state_format in zip(states, flat_structure):
if state_format == C.STEP_STATE or state_format == C.BIAS_STATE:
# Steps and source_bias have batch dimension on axis 0
repeat_axis = 0
elif state_format == C.DECODER_STATE or state_format == C.ENCODER_STATE:
# TODO: Change repeat axis to 1 when interleaved multihead attention is implemented
repeat_axis = 0
# Decoder and encoder layer states have batch dimension on axis 1
repeat_axis = 1
else:
raise ValueError("Provided state format %s not recognized." % state_format)
repeated_state = state.repeat(repeats=beam_size, axis=repeat_axis)
repeated_states.append(repeated_state)
return repeated_states


class SortStates(mx.gluon.HybridBlock):

def __init__(self, state_structure, prefix):
@@ -439,10 +439,11 @@ def hybrid_forward(self, F, best_hyp_indices, *states):
assert len(states) == len(self.flat_structure), "Number of states do not match the defined state structure"
for state, state_format in zip(states, self.flat_structure):
if state_format == C.STEP_STATE or state_format == C.BIAS_STATE:
# Steps and source_bias have batch dimension on axis 0
sorted_state = F.take(state, best_hyp_indices)
elif state_format == C.DECODER_STATE:
# TODO: Change take axis to 1 when interleaved multihead attention is implemented
sorted_state = F.take(state, best_hyp_indices)
# Decoder and encoder layer states have batch dimension on axis 1
sorted_state = F.take(state, best_hyp_indices, axis=1)
elif state_format == C.ENCODER_STATE:
# No need for takes on encoder layer states
sorted_state = state
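To make the axis change above concrete, here is an illustrative snippet (not Sockeye code) showing how a time-major decoder state of shape (seq_len, batch, dim) is tiled and reordered along its batch axis, which is now axis 1; the shapes and indices are invented for the example.

```python
import mxnet as mx

beam_size = 5
# Time-major decoder layer state: (seq_len, batch, model_dim)
state = mx.nd.random.normal(shape=(12, 2, 512))

# Repeat each sentence's state beam_size times along the batch axis (axis 1).
repeated = state.repeat(repeats=beam_size, axis=1)             # (12, 10, 512)

# After a beam-search step, reorder hypotheses with take along the same axis.
best_hyp_indices = mx.nd.array([0, 3, 1, 2, 4, 5, 9, 6, 7, 8])
sorted_state = mx.nd.take(repeated, best_hyp_indices, axis=1)  # (12, 10, 512)
```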
27 changes: 13 additions & 14 deletions sockeye/decoder.py
@@ -163,7 +163,7 @@ def state_structure(self) -> str:
"""
structure = ''
if self.inference_only:
structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers * 2
structure += C.STEP_STATE + C.BIAS_STATE + C.ENCODER_STATE * self.config.num_layers
else:
structure += C.STEP_STATE + C.ENCODER_STATE + C.BIAS_STATE

@@ -197,13 +197,11 @@ def init_state_from_encoder(self,
states = [step, source_mask]

for layer in self.layers:
encoder_attention_keys, encoder_attention_values = \
layer.enc_attention.project_and_isolate_heads(mx.nd, encoder_outputs)
states.append(encoder_attention_keys)
states.append(encoder_attention_values)
enc_att_kv = layer.enc_attention.ff_kv(encoder_outputs)
states.append(mx.nd.transpose(enc_att_kv, axes=(1, 0, 2)))
else:
# NO encoder projection caching
states = [step, encoder_outputs, source_mask]
states = [step, mx.nd.transpose(encoder_outputs, axes=(1, 0, 2)), source_mask]

batch_size = encoder_outputs.shape[0]
dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size),
@@ -271,7 +269,7 @@ def forward(self, step_input, states):

if self.inference_only:
# pass in cached encoder states
encoder_attention_keys_values = states[2:2 + self.config.num_layers * 2]
encoder_attention_keys_values = states[2:2 + self.config.num_layers]
new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states
else:
encoder_outputs = states[1]
@@ -288,14 +286,13 @@ def hybrid_forward(self, F, step_input, states):
if self.inference_only:
steps, source_mask, *other = states
source_encoded = None # use constant pre-computed key value projections from the states
enc_att_kv = other[:self.config.num_layers * 2]
enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)]
autoregr_states = other[self.config.num_layers * 2:]
enc_att_kv = other[:self.config.num_layers]
autoregr_states = other[self.config.num_layers:]
else:
if any(layer.needs_mask for layer in self.layers):
mask = self.autoregressive_bias(step_input) # mask: (1, length, length)
steps, source_encoded, source_mask, *autoregr_states = states
enc_att_kv = [(None, None) for _ in range(self.config.num_layers)]
enc_att_kv = [None for _ in range(self.config.num_layers)]

if any(layer.num_state_tensors > 1 for layer in self.layers):
# separates autoregressive states by layer
@@ -307,23 +304,25 @@

# target: (batch_size, length, model_size)
target = self.pos_embedding(step_input, steps)
# (length, batch_size, model_size)
target = F.transpose(target, axes=(1, 0, 2))

if self.config.dropout_prepost > 0.0:
target = F.Dropout(data=target, p=self.config.dropout_prepost)

new_autoregr_states = []
for layer, layer_autoregr_state, (enc_att_k, enc_att_v) in zip(self.layers, autoregr_states, enc_att_kv):
for layer, layer_autoregr_state, layer_enc_att_kv in zip(self.layers, autoregr_states, enc_att_kv):
target, new_layer_autoregr_state = layer(target,
mask,
source_encoded,
source_mask,
layer_autoregr_state,
enc_att_k, enc_att_v)
layer_enc_att_kv)

new_autoregr_states += [*new_layer_autoregr_state]
# NOTE: the list expansion is needed in order to handle both a tuple (of Symbols) and a Symbol as a new state

target = self.final_process(target, None)
target = F.transpose(target, axes=(1, 0, 2))

return target, new_autoregr_states

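The cached `ff_kv` projection stored in the decoder states above feeds the encoder-decoder variant of the interleaved operators at each decoding step. A hedged sketch with assumed shapes, using the operator names from the MXNet PR linked in the changelog (this is not the exact Sockeye layer code):

```python
import mxnet as mx

heads, head_dim = 8, 64
q_len, kv_len, batch = 1, 20, 4

# Current decoder step, time-major: (q_len, batch, heads * head_dim)
queries = mx.nd.random.normal(shape=(q_len, batch, heads * head_dim))
# Cached, head-interleaved key/value projection of the encoder outputs:
# (kv_len, batch, heads * 2 * head_dim)
keys_values = mx.nd.random.normal(shape=(kv_len, batch, heads * 2 * head_dim))

scores = mx.nd.contrib.interleaved_matmul_encdec_qk(queries, keys_values, heads=heads)
probs = mx.nd.softmax(scores, axis=-1)
# Context vectors, time-major: (q_len, batch, heads * head_dim)
context = mx.nd.contrib.interleaved_matmul_encdec_valatt(keys_values, probs, heads=heads)
```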
2 changes: 2 additions & 0 deletions sockeye/encoder.py
@@ -330,11 +330,13 @@ def hybrid_forward(self, F, data, valid_length):

# (batch_size * heads, 1, seq_len)
bias = F.expand_dims(self.valid_length_mask(data, valid_length), axis=1)
data = F.transpose(data, axes=(1, 0, 2))

for block in self.layers:
data = block(data, bias)

data = self.final_process(data, None)
data = F.transpose(data, axes=(1, 0, 2))
return data, valid_length

def get_num_hidden(self) -> int:
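As in the encoder diff above, batch-major inputs are transposed to time-major once before the transformer layers and back once afterwards, rather than per attention call. A minimal illustration with made-up shapes:

```python
import mxnet as mx

data = mx.nd.random.normal(shape=(4, 10, 512))  # (batch, seq_len, model_dim)
data = mx.nd.transpose(data, axes=(1, 0, 2))    # (seq_len, batch, model_dim)
# ... transformer blocks run on the time-major tensor ...
data = mx.nd.transpose(data, axes=(1, 0, 2))    # back to (batch, seq_len, model_dim)
```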
