From 9bac61eec6ebd06e8c18f220e5921f82a1aaa95f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 14 Mar 2024 17:51:28 -0400 Subject: [PATCH] [TextGeneration] Fix llama tokenizer (#1635) * add llama tokenizer fix * fix generated string * only run for streaming * add TODO --------- Co-authored-by: Dipika Sikka --- .../text_generation/prep_for_generation.py | 1 + .../text_generation/process_outputs.py | 50 +++++++++++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py b/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py index 3318ec88c5..66b0c2a79b 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py @@ -101,6 +101,7 @@ def run( else [], "finished_reason": [], "token_generator": token_generator, + "past_tokens_queue": copy.copy(tokens), } if kv_cache is None: diff --git a/src/deepsparse/transformers/pipelines/text_generation/process_outputs.py b/src/deepsparse/transformers/pipelines/text_generation/process_outputs.py index 6033e10ea4..cae7e24599 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/process_outputs.py +++ b/src/deepsparse/transformers/pipelines/text_generation/process_outputs.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime -from typing import Optional +from typing import List, Optional import numpy @@ -54,6 +54,33 @@ def _create_generated_text_output( finished=False, ) + def _generate_streamed_text_from_past_tokens( + self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int] + ) -> str: + """ + An auxiliary method that helps to properly generate the streamed text. + Some models like llama2 and mistral are using LlamaTokenizer which is + based on SentencePiece tokenizer. This specific tokenizer doesn't seem + to output appropriate prefix spaces when decoding token by token. + One can make it work if the previously generated tokens are included. + This allows the tokenizer to figure out that the appropriate spaces + from last n consecutive tokens. + + :param generated_tokens: the generated tokens from the engine + :param past_tokens_queue: the queue of last n tokens (n is the + original prompt length in tokens) + :return: the generated string + """ + string_from_n_tokens = self.tokenizer.decode( + past_tokens_queue, skip_special_tokens=True + ) + past_tokens_queue.append(generated_tokens[0]) + string_from_n_plus_1_tokens = self.tokenizer.decode( + past_tokens_queue, skip_special_tokens=True + ) + past_tokens_queue.pop(0) + return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]] + def run( self, generated_tokens: numpy.ndarray, @@ -64,9 +91,24 @@ def run( ): generation_config = inference_state.current_state.get("generation_config") generated_logits = generated_logits if generation_config.output_scores else None - sequences = self.tokenizer.batch_decode( - generated_tokens, skip_special_tokens=True - ) + + import transformers + + # Fix for LLAMA-specific models when running streaming + # TODO: make streaming a conditional input to this operator. using inference + # state is a quick fix. + if isinstance( + self.tokenizer, + (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast), + ) and inference_state.current_state.get("streaming"): + past_tokens_queue = inference_state.current_state.get("past_tokens_queue") + sequences = self._generate_streamed_text_from_past_tokens( + generated_tokens, past_tokens_queue + ) + else: + sequences = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) try: finished_reason = [f[-1] for f in finished_reason]