From fc440df9cccce7d933b520f320a546cecf82325b Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Thu, 31 May 2018 09:27:23 +0200 Subject: [PATCH] Refactored pruning logic and vectorized it (#423) This cleans up the pruning logic a little bit and continues work started in #422. Changes include: 1) the modifications on various data structures as a result of pruning are now more local to where it matters. 2) vectorized the pruning function and moved it to `utils.py` (similar to `topk()`). Also using a partial now. Vectorization may help us in moving operations to HybridBlocks in the future. --- sockeye/inference.py | 76 +++++++++++-------------------------- sockeye/utils.py | 27 +++++++++++++ test/unit/test_inference.py | 46 +++++----------------- 3 files changed, 60 insertions(+), 89 deletions(-) diff --git a/sockeye/inference.py b/sockeye/inference.py index e21b9adda..e64193cdb 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -992,10 +992,9 @@ def __init__(self, self.pad_dist = mx.nd.full((self.batch_size * self.beam_size, len(self.vocab_target)), val=np.inf, ctx=self.context) # These are constants used for manipulation of the beam and scores (particularly for pruning) - self.zeros_array = mx.nd.zeros((self.beam_size,), ctx=self.context, dtype='int32') - self.inf_array_long = mx.nd.full((self.batch_size * self.beam_size,), val=np.inf, - ctx=self.context, dtype='float32') - self.inf_array = mx.nd.slice(self.inf_array_long, begin=(0,), end=(self.beam_size,)) + self.zeros_array = mx.nd.zeros((self.batch_size * self.beam_size,), ctx=self.context, dtype='int32') + self.inf_array = mx.nd.full((self.batch_size * self.beam_size, 1), val=np.inf, + ctx=self.context, dtype='float32') # offset for hypothesis indices in batch decoding self.offset = np.repeat(np.arange(0, self.batch_size * self.beam_size, self.beam_size), self.beam_size) @@ -1006,6 +1005,11 @@ def __init__(self, offset=self.offset, use_mxnet_topk=self.context != mx.cpu()) # MXNet implementation is 
faster on GPUs + self.prune = partial(utils.prune, + inf_array=self.inf_array[:, 0], + beam_size=self.beam_size, + prune_threshold=self.beam_prune) + logger.info("Translator (%d model(s) beam_size=%d beam_prune=%s beam_search_stop=%s " "ensemble_mode=%s batch_size=%d buckets_source=%s)", len(self.models), @@ -1304,44 +1308,6 @@ def _combine_predictions(self, neg_logprobs = self.interpolation_func(probs) return neg_logprobs, attention_prob_score - def _prune(self, - accumulated_scores: mx.nd.NDArray, - best_word_indices: mx.nd.NDArray, - inactive: mx.nd.NDArray, - finished: mx.nd.NDArray) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]: - """ - Prunes the beam. For each sentence, we find the best-scoring completed hypothesis (if any), - and then remove all hypotheses for that sentence that are outside the beam relative to that - item. Pruned items are marked by setting their entry in `inactive` to 1 and marking them as finished. - The four arguments are updated in place. - - Note that after pruning, hypotheses are no longer necessarily sorted until the next call to topk(). - - TODO: this could be rewritten with batch-level operations. - - :param accumulated_scores: The accumulated scores. Shape: (batch * beam, 1). - :param best_word_indices: The row indices indicating the best hypotheses. Shape: (batch * beam). - :param inactive: Marks inactive items in the beam. Shape: (batch * beam). - :param finished: Marks completed items in the beam. Shape: (batch * beam). 
- """ - for sentno in range(self.batch_size): - rows = slice(sentno * self.beam_size, (sentno + 1) * self.beam_size) - if mx.nd.sum(finished[rows]) > 0: - best_finished_score = mx.nd.min(mx.nd.where(finished[rows], - accumulated_scores[rows, 0], - self.inf_array)) - - # Find, mark (by setting the score to inf), and remove all hypotheses - # whose score is not within self.beam_prune of the best score - inactive[rows] = mx.nd.cast(accumulated_scores[rows, 0] - best_finished_score > self.beam_prune, - dtype='int32') - accumulated_scores[rows, 0] = mx.nd.where(inactive[rows], self.inf_array, accumulated_scores[rows, 0]) - best_word_indices[rows] = mx.nd.where(inactive[rows], self.zeros_array, best_word_indices[rows]) - - # mark removed ones as finished so they won't block early exiting - finished[rows] = mx.nd.clip(finished[rows] + inactive[rows], 0, 1) - return accumulated_scores, best_word_indices, inactive, finished - def _beam_search(self, source: mx.nd.NDArray, source_length: int, @@ -1464,15 +1430,17 @@ def _beam_search(self, # (2) Special treatment for finished and inactive rows. Inactive rows are inf everywhere; # finished rows are inf everywhere except column zero, which holds the accumulated model score scores += scores_accumulated - # Items that are finished (but not inactive) get the accumulated score in col 0, - # otherwise infinity for the whole row + # Items that are finished (but not inactive) get their previous accumulated score for the symbol, + # infinity otherwise. + # pylint: disable=invalid-sequence-index pad_dist[:, C.PAD_ID] = mx.nd.where(mx.nd.clip(finished - inactive, 0, 1), scores_accumulated[:, 0], - self.inf_array_long) + self.inf_array[:, 0]) scores = mx.nd.where(finished + inactive, pad_dist, scores) # (3) Get beam_size winning hypotheses for each sentence block separately. Only look as # far as the active beam size for each sentence. 
+ # pylint: disable=unsupported-assignment-operation best_hyp_indices[:], best_word_indices[:], scores_accumulated[:, 0] = self.topk(scores) # Constraints for constrained decoding are processed sentence by sentence @@ -1515,22 +1483,24 @@ def _beam_search(self, # (6) Prune out low-probability hypotheses. Pruning works by setting entries `inactive`. if self.beam_prune > 0.0: - scores_accumulated, best_word_indices, inactive, finished = self._prune(scores_accumulated, - best_word_indices, - inactive, - finished) + inactive = self.prune(scores_accumulated, finished) + best_word_indices = mx.nd.where(inactive, self.zeros_array, best_word_indices) + scores_accumulated = mx.nd.where(inactive, self.inf_array, scores_accumulated) + finished_or_inactive = (finished + inactive).clip(0, 1) # (7) update best hypotheses, their attention lists and lengths (only for non-finished hyps) # pylint: disable=unsupported-assignment-operation sequences[:, t] = best_word_indices attentions[:, t, :] = attention_scores - lengths += mx.nd.cast(1 - mx.nd.expand_dims(finished, axis=1), dtype='float32') + lengths += mx.nd.cast(1 - mx.nd.expand_dims(finished_or_inactive, axis=1), dtype='float32') # (6) optionally save beam history if self.store_beam: - unnormalized_scores = mx.nd.where(finished, scores_accumulated * self.length_penalty(lengths - 1), + unnormalized_scores = mx.nd.where(finished_or_inactive, + scores_accumulated * self.length_penalty(lengths - 1), scores_accumulated) - normalized_scores = mx.nd.where(finished, scores_accumulated, + normalized_scores = mx.nd.where(finished_or_inactive, + scores_accumulated, scores_accumulated / self.length_penalty(lengths - 1)) for sent in range(self.batch_size): rows = slice(sent * self.beam_size, (sent + 1) * self.beam_size) @@ -1611,7 +1581,7 @@ def _get_best_from_beam(self, if any(constraints): # For constrained decoding, select from items that have met all constraints (might not be finished) unmet = mx.nd.array([c.num_needed() if c is not None 
def prune(scores: mx.nd.NDArray,
          finished: mx.nd.NDArray,
          inf_array: mx.nd.NDArray,
          beam_size: int,
          prune_threshold: float) -> mx.nd.NDArray:
    """
    Computes a 0-1 array marking the hypotheses that fall outside the pruning threshold.
    All inputs are viewed as rows of beam_size entries (one row per sentence). Within each row,
    every hypothesis (finished or not) whose score exceeds the best finished score of that row
    by more than prune_threshold is marked with 1.

    :param scores: Hypotheses scores. Shape: (batch * beam, 1).
    :param finished: 0-1 array indicating which hypotheses are finished. Shape: (batch * beam,).
    :param inf_array: Auxiliary array, expected to be filled with infinity. Shape: (batch * beam,).
    :param beam_size: Beam size.
    :param prune_threshold: Pruning threshold.
    :return: Int32 NDArray of inactive items. Shape: (batch * beam,).
    """
    per_sentence = (-1, beam_size)
    sentence_scores = scores.reshape(per_sentence)
    is_finished = finished.reshape(per_sentence)
    fill_values = inf_array.reshape(per_sentence)

    # Unfinished entries are replaced by the fill value before taking the row-wise minimum, so with
    # an infinite fill value a row without any finished hypothesis gets an infinite best score and
    # no finite-scored hypothesis in that row can exceed the threshold.
    finished_scores = mx.nd.where(is_finished, sentence_scores, fill_values)
    # Shape: (batch, 1); broadcasts against the (batch, beam) scores below.
    best_finished_scores = finished_scores.min(axis=1, keepdims=True)

    outside_threshold = (sentence_scores - best_finished_scores) > prune_threshold
    return mx.nd.cast(outside_threshold, dtype='int32').reshape(-1)
@pytest.mark.parametrize("batch, beam, prune, scores, finished, expected_inactive", prune_tests)
def test_beam_prune(batch, beam, prune, scores, finished, expected_inactive):
    """
    Checks that sockeye.utils.prune() marks exactly the expected hypotheses as inactive
    for every case listed in prune_tests.
    """
    num_hyps = batch * beam
    # prune() receives the scores with a trailing axis of size 1
    score_column = mx.nd.array(scores).expand_dims(axis=1)
    finished_flags = mx.nd.array(finished, dtype='int32')
    infinities = mx.nd.full((num_hyps,), val=np.inf)

    pruned = sockeye.utils.prune(score_column, finished_flags, infinities, beam, prune)
    assert pruned.asnumpy().tolist() == expected_inactive