Skip to content

Commit

Permalink
[Feat] Add streaming to chatbot (#1272)
Browse files Browse the repository at this point in the history
* Add changes needed on parent class side

* Add stream mode to chatbot
Clean functions

* Update sampling temperature to temperature

* Update src/deepsparse/transformers/infer.py

* Rebase and fix streaming

* Fix broken data pathway

* Add stream to data mode
update arg name to prompt

* Update prompt prefill tok/s

---------

Co-authored-by: Michael Goin <michael@neuralmagic.com>
  • Loading branch information
rahul-tuli and mgoin authored Oct 16, 2023
1 parent 3dfc256 commit e7ec235
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 36 deletions.
100 changes: 65 additions & 35 deletions src/deepsparse/transformers/inference/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@
--task TEXT The task to use for the pipeline. Choose any
of `chat`, `codegen`, `text-generation`
[default: chat]
--help Show this message and exit.
--stream / --no_stream Whether to stream output as generated or not
[default: no_stream]
--help Show this message and exit. [default:
False]
Installation: pip install deepsparse[transformers]
Examples:
Expand All @@ -62,6 +65,10 @@
4) Disable history
deepsparse.infer models/llama/deployment \
--task text-generation
5) Stream output
deepsparse.infer models/llama/deployment \
--stream
"""

from typing import Optional
Expand Down Expand Up @@ -106,7 +113,7 @@
@click.option(
"--prompt_sequence_length",
type=int,
default=64,
default=16,
help="Processed prompt in chunks of this length. "
"This is to maximize the inference speed",
)
Expand All @@ -117,11 +124,17 @@
)
@click.option(
"--task",
default="chat",
default="text-generation",
type=str,
help="The task to use for the pipeline. Choose any of "
"`chat`, `codegen`, `text-generation`",
)
@click.option(
"--stream/--no_stream",
is_flag=True,
default=False,
help="Whether to stream output as generated or not",
)
def main(
model_path: str,
data: Optional[str],
Expand All @@ -130,6 +143,7 @@ def main(
prompt_sequence_length: int,
show_tokens_per_sec: bool,
task: str,
stream: bool,
):
"""
Command Line utility to interact with a text genration LLM in a chatbot style
Expand All @@ -152,7 +166,6 @@ def main(
default_prompt_kwargs = {
"sequence_length": sequence_length,
"sampling_temperature": sampling_temperature,
"prompt_sequence_length": prompt_sequence_length,
"show_tokens_per_sec": show_tokens_per_sec,
}

Expand All @@ -161,57 +174,74 @@ def main(
task=task,
pipeline=pipeline,
session_ids=session_ids,
stream=stream,
**prompt_kwargs,
)
return

# continue prompts until a keyboard interrupt
while data is None: # always True in interactive Mode
prompt = input(">>> ")
while True:
input_text = input("User: ")
_run_inference(
pipeline,
sampling_temperature,
task,
session_ids,
show_tokens_per_sec,
prompt_sequence_length,
prompt,
pipeline=pipeline,
sampling_temperature=sampling_temperature,
task=task,
session_ids=session_ids,
show_tokens_per_sec=show_tokens_per_sec,
stream=stream,
prompt=input_text,
)


def _run_inference(
pipeline,
sampling_temperature,
task,
session_ids,
show_tokens_per_sec,
prompt_sequence_length,
prompt,
**kwargs,
pipeline: Pipeline,
sampling_temperature: float,
task: str,
session_ids: str,
show_tokens_per_sec: bool,
prompt: str,
stream: bool = False,
):
pipeline_inputs = dict(
prompt=[prompt],
temperature=sampling_temperature,
**kwargs,
)

if SupportedTasks.is_chat(task):
pipeline_inputs["session_ids"] = session_ids

response = pipeline(**pipeline_inputs)
print("\n", response.generations[0].text)
response = pipeline(**pipeline_inputs, streaming=stream)
_display_bot_response(stream, response)

if show_tokens_per_sec:
times = pipeline.timer_manager.times
prefill_speed = (
1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
)
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
)
_display_generation_speed(prompt, pipeline)


def _display_generation_speed(prompt, pipeline):
# display prefill and generation speed(s) in tokens/sec
times = pipeline.timer_manager.times
prefill_speed = (
len(pipeline.tokenizer(prompt)["input_ids"]) / times["engine_prompt_prefill"]
)
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
)


def _display_bot_response(stream: bool, response):
# print response from pipeline, streaming or not

print("Bot:", end="", flush=True)
if stream:
for generation in response:
print(generation.generations[0].text, end="", flush=True)
print()
else:
print(response.generations[0].text)


if __name__ == "__main__":
if "__main__" == __name__:
main()
1 change: 0 additions & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ def initialize_engines(
if (
self.cache_support_enabled and self.enable_multitoken_prefill
) or not self.cache_support_enabled:

# input_ids_length for the multitoken engine is either:
# - the prompt_sequence_length if the cache support is enabled
# (the prompt is processed sequentially at predefined processing length)
Expand Down

0 comments on commit e7ec235

Please sign in to comment.