Skip to content

Commit

Permalink
Add text iterator streamer
Browse files Browse the repository at this point in the history
Signed-off-by: Aisuko <urakiny@gmail.com>
Aisuko committed Apr 7, 2024
1 parent 388981e commit 86804bb
Showing 2 changed files with 29 additions and 3 deletions.
2 changes: 0 additions & 2 deletions src/kimchima/pipelines/pipelines_factory.py
Original file line number Diff line number Diff line change
@@ -43,8 +43,6 @@ def text_generation(cls, *args,**kwargs)-> pipeline:
if tokenizer is None:
raise ValueError("tokenizer is required")
streamer=kwargs.pop("text_streamer", None)
if streamer is None:
raise ValueError("text_streamer is required")
max_new_tokens=kwargs.pop("max_new_tokens", 20)
quantization_config=kwargs.pop("quantization_config", None)
if quantization_config is None:
30 changes: 29 additions & 1 deletion src/kimchima/pkg/streamer_factory.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,10 @@

from kimchima.pkg import logging

from transformers import TextStreamer
from transformers import (
TextStreamer,
TextIteratorStreamer
)

logger = logging.get_logger(__name__)

@@ -55,3 +58,28 @@ def text_streamer(cls, *args, **kwargs)-> TextStreamer:
logger.info("TextStreamer created")

return streamer

@classmethod
def text_iterator_streamer(cls, *args, **kwargs)-> TextIteratorStreamer:
r"""
Get streamer for text generation task.
Args:
skip_prompt: skip prompt
skip_prompt_tokens: skip prompt tokens
"""
#TODO support more parameters
tokenizer=kwargs.pop('tokenizer', None)
if tokenizer is None:
raise ValueError("tokenizer is required")
skip_prompt=kwargs.pop('skip_prompt', False)
skip_prompt_tokens=kwargs.pop('skip_prompt_tokens', False)

streamer=TextIteratorStreamer(
tokenizer=tokenizer,
skip_prompt=skip_prompt,
skip_prompt_tokens=skip_prompt_tokens
)
logger.info("TextIteratorStreamer created")

return streamer

0 comments on commit 86804bb

Please sign in to comment.