Skip to content

Commit

Permalink
Refactored pruning logic and vectorized it (#423)
Browse files Browse the repository at this point in the history
This cleans up the pruning logic a little bit and continues work started in #422.

Changes include:
1) the modifications on various data structures as a result of pruning are now more local to where it matters.
2) vectorized the pruning function and moved it to `utils.py` (similar to `topk()`). Also using a partial now. Vectorization may help us in moving operations to HybridBlocks in the future.
  • Loading branch information
fhieber authored May 31, 2018
1 parent 5db05f4 commit fc440df
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 89 deletions.
76 changes: 23 additions & 53 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,10 +992,9 @@ def __init__(self,
self.pad_dist = mx.nd.full((self.batch_size * self.beam_size, len(self.vocab_target)), val=np.inf,
ctx=self.context)
# These are constants used for manipulation of the beam and scores (particularly for pruning)
self.zeros_array = mx.nd.zeros((self.beam_size,), ctx=self.context, dtype='int32')
self.inf_array_long = mx.nd.full((self.batch_size * self.beam_size,), val=np.inf,
ctx=self.context, dtype='float32')
self.inf_array = mx.nd.slice(self.inf_array_long, begin=(0,), end=(self.beam_size,))
self.zeros_array = mx.nd.zeros((self.batch_size * self.beam_size,), ctx=self.context, dtype='int32')
self.inf_array = mx.nd.full((self.batch_size * self.beam_size, 1), val=np.inf,
ctx=self.context, dtype='float32')

# offset for hypothesis indices in batch decoding
self.offset = np.repeat(np.arange(0, self.batch_size * self.beam_size, self.beam_size), self.beam_size)
Expand All @@ -1006,6 +1005,11 @@ def __init__(self,
offset=self.offset,
use_mxnet_topk=self.context != mx.cpu()) # MXNet implementation is faster on GPUs

self.prune = partial(utils.prune,
inf_array=self.inf_array[:, 0],
beam_size=self.beam_size,
prune_threshold=self.beam_prune)

logger.info("Translator (%d model(s) beam_size=%d beam_prune=%s beam_search_stop=%s "
"ensemble_mode=%s batch_size=%d buckets_source=%s)",
len(self.models),
Expand Down Expand Up @@ -1304,44 +1308,6 @@ def _combine_predictions(self,
neg_logprobs = self.interpolation_func(probs)
return neg_logprobs, attention_prob_score

def _prune(self,
accumulated_scores: mx.nd.NDArray,
best_word_indices: mx.nd.NDArray,
inactive: mx.nd.NDArray,
finished: mx.nd.NDArray) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
"""
Prunes the beam. For each sentence, we find the best-scoring completed hypothesis (if any),
and then remove all hypotheses for that sentence that are outside the beam relative to that
item. Pruned items are marked by setting their entry in `inactive` to 1 and marking them as finished.
The four arguments are updated in place.
Note that after pruning, hypotheses are no longer necessarily sorted until the next call to topk().
TODO: this could be rewritten with batch-level operations.
:param accumulated_scores: The accumulated scores. Shape: (batch * beam, 1).
:param best_word_indices: The row indices indicating the best hypotheses. Shape: (batch * beam).
:param inactive: Marks inactive items in the beam. Shape: (batch * beam).
:param finished: Marks completed items in the beam. Shape: (batch * beam).
"""
for sentno in range(self.batch_size):
rows = slice(sentno * self.beam_size, (sentno + 1) * self.beam_size)
if mx.nd.sum(finished[rows]) > 0:
best_finished_score = mx.nd.min(mx.nd.where(finished[rows],
accumulated_scores[rows, 0],
self.inf_array))

# Find, mark (by setting the score to inf), and remove all hypotheses
# whose score is not within self.beam_prune of the best score
inactive[rows] = mx.nd.cast(accumulated_scores[rows, 0] - best_finished_score > self.beam_prune,
dtype='int32')
accumulated_scores[rows, 0] = mx.nd.where(inactive[rows], self.inf_array, accumulated_scores[rows, 0])
best_word_indices[rows] = mx.nd.where(inactive[rows], self.zeros_array, best_word_indices[rows])

# mark removed ones as finished so they won't block early exiting
finished[rows] = mx.nd.clip(finished[rows] + inactive[rows], 0, 1)
return accumulated_scores, best_word_indices, inactive, finished

def _beam_search(self,
source: mx.nd.NDArray,
source_length: int,
Expand Down Expand Up @@ -1464,15 +1430,17 @@ def _beam_search(self,
# (2) Special treatment for finished and inactive rows. Inactive rows are inf everywhere;
# finished rows are inf everywhere except column zero, which holds the accumulated model score
scores += scores_accumulated
# Items that are finished (but not inactive) get the accumulated score in col 0,
# otherwise infinity for the whole row
# Items that are finished (but not inactive) get their previous accumulated score for the <pad> symbol,
# infinity otherwise.
# pylint: disable=invalid-sequence-index
pad_dist[:, C.PAD_ID] = mx.nd.where(mx.nd.clip(finished - inactive, 0, 1),
scores_accumulated[:, 0],
self.inf_array_long)
self.inf_array[:, 0])
scores = mx.nd.where(finished + inactive, pad_dist, scores)

# (3) Get beam_size winning hypotheses for each sentence block separately. Only look as
# far as the active beam size for each sentence.
# pylint: disable=unsupported-assignment-operation
best_hyp_indices[:], best_word_indices[:], scores_accumulated[:, 0] = self.topk(scores)

# Constraints for constrained decoding are processed sentence by sentence
Expand Down Expand Up @@ -1515,22 +1483,24 @@ def _beam_search(self,

# (6) Prune out low-probability hypotheses. Pruning works by setting entries `inactive`.
if self.beam_prune > 0.0:
scores_accumulated, best_word_indices, inactive, finished = self._prune(scores_accumulated,
best_word_indices,
inactive,
finished)
inactive = self.prune(scores_accumulated, finished)
best_word_indices = mx.nd.where(inactive, self.zeros_array, best_word_indices)
scores_accumulated = mx.nd.where(inactive, self.inf_array, scores_accumulated)
finished_or_inactive = (finished + inactive).clip(0, 1)

# (7) update best hypotheses, their attention lists and lengths (only for non-finished hyps)
# pylint: disable=unsupported-assignment-operation
sequences[:, t] = best_word_indices
attentions[:, t, :] = attention_scores
lengths += mx.nd.cast(1 - mx.nd.expand_dims(finished, axis=1), dtype='float32')
lengths += mx.nd.cast(1 - mx.nd.expand_dims(finished_or_inactive, axis=1), dtype='float32')

# (6) optionally save beam history
if self.store_beam:
unnormalized_scores = mx.nd.where(finished, scores_accumulated * self.length_penalty(lengths - 1),
unnormalized_scores = mx.nd.where(finished_or_inactive,
scores_accumulated * self.length_penalty(lengths - 1),
scores_accumulated)
normalized_scores = mx.nd.where(finished, scores_accumulated,
normalized_scores = mx.nd.where(finished_or_inactive,
scores_accumulated,
scores_accumulated / self.length_penalty(lengths - 1))
for sent in range(self.batch_size):
rows = slice(sent * self.beam_size, (sent + 1) * self.beam_size)
Expand Down Expand Up @@ -1611,7 +1581,7 @@ def _get_best_from_beam(self,
if any(constraints):
# For constrained decoding, select from items that have met all constraints (might not be finished)
unmet = mx.nd.array([c.num_needed() if c is not None else 0 for c in constraints], ctx=self.context)
filtered = mx.nd.where(unmet == 0, seq_scores[:, 0], self.inf_array_long)
filtered = mx.nd.where(unmet == 0, seq_scores, self.inf_array)
filtered = filtered.reshape((self.batch_size, self.beam_size))
best_ids += mx.nd.argmin(filtered, axis=1)

Expand Down
27 changes: 27 additions & 0 deletions sockeye/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,33 @@ def topk(scores: mx.nd.NDArray,
return best_hyp_indices, best_word_indices, values


def prune(scores: mx.nd.NDArray,
finished: mx.nd.NDArray,
inf_array: mx.nd.NDArray,
beam_size: int,
prune_threshold: float) -> mx.nd.NDArray:
"""
Returns a 0-1 array indicating which hypotheses are inactive based on pruning.
Finished hypotheses that have a score worse than prune_threshold from the best scoring hypotheses
are marked as inactive.
:param scores: Hypotheses scores. Shape: (batch * beam, 1).
:param finished: 0-1 array indicating which hypotheses are finished. Shape: (batch * beam,).
:param inf_array: Auxiliary array filled with infinity. Shape: (batch * beam,).
:param beam_size: Beam size.
:param prune_threshold: Pruning threshold.
:return NDArray of inactive items. Shape(batch * beam,).
"""
scores = scores.reshape((-1, beam_size))
finished = finished.reshape((-1, beam_size))
inf_array = inf_array.reshape((-1, beam_size))

# best finished scores. Shape: (batch, 1)
best_finished_scores = mx.nd.where(finished, scores, inf_array).min(axis=1, keepdims=True)
inactive = mx.nd.cast((scores - best_finished_scores) > prune_threshold, dtype='int32').reshape((-1))
return inactive


def chunks(some_list: List, n: int) -> Iterable[List]:
"""Yield successive n-sized chunks from l."""
for i in range(0, len(some_list), n):
Expand Down
46 changes: 10 additions & 36 deletions test/unit/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# permissions and limitations under the License.

import json
from math import ceil
from unittest.mock import patch, Mock

import mxnet as mx
import numpy as np
import pytest
from math import ceil

import sockeye.constants as C
import sockeye.data_io
import sockeye.inference
from sockeye.utils import SockeyeError
import sockeye.utils

_BOS = 0
_EOS = -1
Expand Down Expand Up @@ -56,14 +56,15 @@ def mock_model():
t_mock = Mock(sockeye.inference.InferenceModel)
t_mock.num_source_factors = num_source_factors
return t_mock
translator.models = [ mock_model() ]

translator.models = [mock_model()]

translator.batch_size = batch_size
translator.beam_size = beam_size
translator.beam_prune = beam_prune
translator.zeros_array = mx.nd.zeros((beam_size,), dtype='int32')
translator.inf_array_long = mx.nd.full((batch_size * beam_size,), val=np.inf, dtype='float32')
translator.inf_array = mx.nd.slice(translator.inf_array_long, begin=(0), end=(beam_size))
translator.inf_array = mx.nd.full((batch_size * beam_size,), val=np.inf, dtype='float32')
translator.inf_array = mx.nd.slice(translator.inf_array, begin=(0), end=(beam_size))
return translator


Expand Down Expand Up @@ -262,7 +263,7 @@ def test_make_input_whitespace_delimiter(delimiter):
sentence_id = 1
translator = mock_translator(num_source_factors=2)
sentence = "foo"
with pytest.raises(SockeyeError) as e:
with pytest.raises(sockeye.utils.SockeyeError) as e:
sockeye.inference.make_input_from_factored_string(sentence_id=sentence_id,
factored_string=sentence,
translator=translator, delimiter=delimiter)
Expand Down Expand Up @@ -309,19 +310,7 @@ def test_make_input_from_multiple_strings(strings):
assert inp.tokens == expected_tokens
assert inp.factors == expected_factors

"""
Test pruning via inference.Translator._beam_prune(). The best score is computed from the best
finished item; all other items whose scores are outside (best_item - threshold) are pruned, which
means their spot in `inactive` is set to 1.
Tests: take in
- accumulated_scores and finished
- a dummy inactive (not read from)
- best_word_indices (maybe dummy?)
and check
- values of finished and invalid
- maybe values of best_word_indices
"""

# batch size, beam size, prune thresh, accumulated scores, finished, expected_inactive
prune_tests = [
# no pruning because nothing is finished
Expand All @@ -341,24 +330,9 @@ def test_make_input_from_multiple_strings(strings):

@pytest.mark.parametrize("batch, beam, prune, scores, finished, expected_inactive", prune_tests)
def test_beam_prune(batch, beam, prune, scores, finished, expected_inactive):
translator = mock_translator(batch, beam, prune)

orig_finished = [x for x in finished]

# these are passed by reference and changed, so create them here
scores = mx.nd.array(scores).expand_dims(axis=1)
inactive = mx.nd.array([0] * (batch * beam), dtype='int32')
best_words = mx.nd.array([10] * (batch * beam), dtype='int32')
finished = mx.nd.array(finished, dtype='int32')
inf_array = mx.nd.full((batch * beam,), val=np.inf)

translator._prune(scores, best_words, inactive, finished)

# Make sure inactive is set as expected
inactive = sockeye.utils.prune(scores, finished, inf_array, beam, prune)
assert inactive.asnumpy().tolist() == expected_inactive

# Ensure that scores for inactive items are set to 'inf'
zeros = mx.nd.zeros((beam * batch,), dtype='float32')
assert mx.nd.where(inactive, scores[:, 0], zeros).asnumpy().tolist() == [np.inf if x == 1 else 0 for x in expected_inactive]

# Inactive items should also be marked as finished
assert finished.asnumpy().tolist() == np.clip(np.array(orig_finished) + np.array(expected_inactive), 0, 1).tolist()

0 comments on commit fc440df

Please sign in to comment.