From e7ec235f81246fc6884436172b1cc348c9a06384 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Mon, 16 Oct 2023 10:06:56 -0400
Subject: [PATCH] [Feat] Add streaming to chatbot (#1272)

* Add changes needed on parent class side

* Add stream mode to chatbot
Clean functions

* Update sampling temperature to temperature

* Update src/deepsparse/transformers/infer.py

* Rebase and fix streaming

* Fix broken data pathway

* Add stream to data mode
update arg name to prompt

* Update prompt prefill tok/s

---------

Co-authored-by: Michael Goin
---
 .../transformers/inference/infer.py           | 100 ++++++++++++------
 .../transformers/pipelines/text_generation.py |   1 -
 2 files changed, 65 insertions(+), 36 deletions(-)

diff --git a/src/deepsparse/transformers/inference/infer.py b/src/deepsparse/transformers/inference/infer.py
index 460b8499c4..0dbfa92a62 100644
--- a/src/deepsparse/transformers/inference/infer.py
+++ b/src/deepsparse/transformers/inference/infer.py
@@ -43,7 +43,10 @@
   --task TEXT              The task to use for the pipeline. Choose any of
                            `chat`, `codegen`, `text-generation`
                            [default: chat]
-  --help                   Show this message and exit.
+  --stream / --no_stream   Whether to stream output as generated or not
+                           [default: no_stream]
+  --help                   Show this message and exit.  [default:
+                           False]
 
 Installation: pip install deepsparse[transformers]
 Examples:
@@ -62,6 +65,10 @@
 4) Disable history
 deepsparse.infer models/llama/deployment \
     --task text-generation
+
+5) Stream output
+deepsparse.infer models/llama/deployment \
+    --stream
 """
 
 from typing import Optional
@@ -106,7 +113,7 @@
 @click.option(
     "--prompt_sequence_length",
     type=int,
-    default=64,
+    default=16,
     help="Processed prompt in chunks of this length. "
     "This is to maximize the inference speed",
 )
@@ -117,11 +124,17 @@
 )
 @click.option(
     "--task",
-    default="chat",
+    default="text-generation",
     type=str,
     help="The task to use for the pipeline. Choose any of "
Choose any of " "`chat`, `codegen`, `text-generation`", ) +@click.option( + "--stream/--no_stream", + is_flag=True, + default=False, + help="Whether to stream output as generated or not", +) def main( model_path: str, data: Optional[str], @@ -130,6 +143,7 @@ def main( prompt_sequence_length: int, show_tokens_per_sec: bool, task: str, + stream: bool, ): """ Command Line utility to interact with a text genration LLM in a chatbot style @@ -152,7 +166,6 @@ def main( default_prompt_kwargs = { "sequence_length": sequence_length, "sampling_temperature": sampling_temperature, - "prompt_sequence_length": prompt_sequence_length, "show_tokens_per_sec": show_tokens_per_sec, } @@ -161,57 +174,74 @@ def main( task=task, pipeline=pipeline, session_ids=session_ids, + stream=stream, **prompt_kwargs, ) return # continue prompts until a keyboard interrupt - while data is None: # always True in interactive Mode - prompt = input(">>> ") + while True: + input_text = input("User: ") _run_inference( - pipeline, - sampling_temperature, - task, - session_ids, - show_tokens_per_sec, - prompt_sequence_length, - prompt, + pipeline=pipeline, + sampling_temperature=sampling_temperature, + task=task, + session_ids=session_ids, + show_tokens_per_sec=show_tokens_per_sec, + stream=stream, + prompt=input_text, ) def _run_inference( - pipeline, - sampling_temperature, - task, - session_ids, - show_tokens_per_sec, - prompt_sequence_length, - prompt, - **kwargs, + pipeline: Pipeline, + sampling_temperature: float, + task: str, + session_ids: str, + show_tokens_per_sec: bool, + prompt: str, + stream: bool = False, ): pipeline_inputs = dict( prompt=[prompt], temperature=sampling_temperature, - **kwargs, ) + if SupportedTasks.is_chat(task): pipeline_inputs["session_ids"] = session_ids - response = pipeline(**pipeline_inputs) - print("\n", response.generations[0].text) + response = pipeline(**pipeline_inputs, streaming=stream) + _display_bot_response(stream, response) if show_tokens_per_sec: - times = pipeline.timer_manager.times - prefill_speed = ( - 1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"] - ) - generation_speed = 1.0 / times["engine_token_generation_single"] - print( - f"[prefill: {prefill_speed:.2f} tokens/sec]", - f"[decode: {generation_speed:.2f} tokens/sec]", - sep="\n", - ) + _display_generation_speed(prompt, pipeline) + + +def _display_generation_speed(prompt, pipeline): + # display prefill and generation speed(s) in tokens/sec + times = pipeline.timer_manager.times + prefill_speed = ( + len(pipeline.tokenizer(prompt)["input_ids"]) / times["engine_prompt_prefill"] + ) + generation_speed = 1.0 / times["engine_token_generation_single"] + print( + f"[prefill: {prefill_speed:.2f} tokens/sec]", + f"[decode: {generation_speed:.2f} tokens/sec]", + sep="\n", + ) + + +def _display_bot_response(stream: bool, response): + # print response from pipeline, streaming or not + + print("Bot:", end="", flush=True) + if stream: + for generation in response: + print(generation.generations[0].text, end="", flush=True) + print() + else: + print(response.generations[0].text) -if __name__ == "__main__": +if "__main__" == __name__: main() diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index d0844afc30..06c1d9750c 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -329,7 +329,6 @@ def initialize_engines( if ( self.cache_support_enabled and 
diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index d0844afc30..06c1d9750c 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -329,7 +329,6 @@ def initialize_engines(
         if (
             self.cache_support_enabled and self.enable_multitoken_prefill
         ) or not self.cache_support_enabled:
-
             # input_ids_length for the multitoken engine is either:
             # - the prompt_sequence_length if the cache support is enabled
             #   (the prompt is processed sequentially at predefined processing length)
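Note: the other functional change in infer.py is the prefill-speed formula, which now divides the actual prompt token count (len(pipeline.tokenizer(prompt)["input_ids"])) by the total prefill time instead of assuming a fixed prompt_sequence_length. A rough, self-contained illustration of that arithmetic, with made-up numbers:

    # Hypothetical stand-ins for the tokenized prompt length and the
    # pipeline.timer_manager.times entries used by _display_generation_speed.
    prompt_token_count = 42                 # tokens in the user prompt
    engine_prompt_prefill = 0.35            # seconds spent prefilling the whole prompt
    engine_token_generation_single = 0.02   # seconds per generated token

    prefill_speed = prompt_token_count / engine_prompt_prefill     # 120.00 tokens/sec
    generation_speed = 1.0 / engine_token_generation_single        # 50.00 tokens/sec
    print(f"[prefill: {prefill_speed:.2f} tokens/sec]")
    print(f"[decode: {generation_speed:.2f} tokens/sec]")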