diff --git a/examples/embeddings_examples.py b/examples/embeddings_examples.py
index b38b7b2..35ea250 100644
--- a/examples/embeddings_examples.py
+++ b/examples/embeddings_examples.py
@@ -1,7 +1,10 @@
 from kimchima import (
     ModelFactory,
     TokenizerFactory,
+    StreamerFactory,
     EmbeddingsFactory,
+    QuantizationFactory,
+    PipelinesFactory,
     Devices
     )
@@ -36,4 +39,19 @@
 
 # get capability of GPU(Nvidia)
 capability = Devices.get_capability()
-print(capability)
\ No newline at end of file
+print(capability)
+
+
+# streamer
+model = ModelFactory.auto_model_for_causal_lm(pretrained_model_name_or_path="gpt2")
+tokenizer = TokenizerFactory.auto_tokenizer(pretrained_model_name_or_path="gpt2")
+streamer = StreamerFactory.text_streamer(tokenizer=tokenizer, skip_prompt=False, skip_prompt_tokens=False)
+
+
+pipe = PipelinesFactory.text_generation(
+    model=model,
+    tokenizer=tokenizer,
+    text_streamer=streamer
+    )
+
+pipe("Melbourne is the capital of Victoria")
\ No newline at end of file
diff --git a/src/kimchima/pipelines/pipelines_factory.py b/src/kimchima/pipelines/pipelines_factory.py
index f8b7e2c..1b216fc 100644
--- a/src/kimchima/pipelines/pipelines_factory.py
+++ b/src/kimchima/pipelines/pipelines_factory.py
@@ -44,9 +44,6 @@ def text_generation(cls, *args,**kwargs)-> pipeline:
             raise ValueError("tokenizer is required")
         streamer=kwargs.pop("text_streamer", None)
         max_new_tokens=kwargs.pop("max_new_tokens", 20)
-        quantization_config=kwargs.pop("quantization_config", None)
-        if quantization_config is None:
-            raise ValueError("quantization_config is required")
 
         pipe=pipeline(
             task="text-generation",
@@ -54,7 +51,6 @@ def text_generation(cls, *args,**kwargs)-> pipeline:
             tokenizer=tokenizer,
             streamer=streamer,
             max_new_tokens=max_new_tokens,
-            quantization_config=quantization_config,
             device_map='auto',
             **kwargs
         )
diff --git a/src/kimchima/pkg/model_factory.py b/src/kimchima/pkg/model_factory.py
index 33746d6..9a21343 100644
--- a/src/kimchima/pkg/model_factory.py
+++ b/src/kimchima/pkg/model_factory.py
@@ -34,7 +34,7 @@ def __init__(self):
         )
 
     @classmethod
-    def auto_model(cls, pretrained_model_name_or_path, **kwargs)-> AutoModel:
+    def auto_model(cls, *args, **kwargs)-> AutoModel:
         r"""
         It is used to get the model from the Hugging Face Transformers AutoModel.
 
@@ -42,14 +42,21 @@ def auto_model(cls, pretrained_model_name_or_path, **kwargs)-> AutoModel:
             pretrained_model_name_or_path: pretrained model name or path
 
         """
+        pretrained_model_name_or_path=kwargs.pop("pretrained_model_name_or_path", None)
         if pretrained_model_name_or_path is None:
             raise ValueError("pretrained_model_name_or_path cannot be None")
-        model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        quantization_config=kwargs.pop("quantization_config", None)
+        model = AutoModel.from_pretrained(
+            pretrained_model_name_or_path,
+            quantization_config=quantization_config,
+            **kwargs
+            )
         logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
         return model
 
     @classmethod
-    def auto_model_for_causal_lm(cls, pretrained_model_name_or_path, **kwargs)-> AutoModelForCausalLM:
+    def auto_model_for_causal_lm(cls, *args, **kwargs)-> AutoModelForCausalLM:
         r"""
         It is used to get the model from the Hugging Face Transformers AutoModelForCausalLM.
 
@@ -57,9 +64,17 @@ def auto_model_for_causal_lm(cls, pretrained_model_name_or_path, **kwargs)-> AutoModelForCausalLM:
             pretrained_model_name_or_path: pretrained model name or path
 
         """
+        pretrained_model_name_or_path=kwargs.pop("pretrained_model_name_or_path", None)
         if pretrained_model_name_or_path is None:
             raise ValueError("pretrained_model_name_or_path cannot be None")
-        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        quantization_config=kwargs.pop("quantization_config", None)
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            quantization_config=quantization_config,
+            device_map='auto',
+            **kwargs
+            )
         logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
         return model
 
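
Usage sketch for the relocated quantization path: quantization_config is no longer accepted by PipelinesFactory.text_generation and is instead popped and forwarded by ModelFactory.auto_model_for_causal_lm. This is a minimal sketch, assuming a plain transformers.BitsAndBytesConfig is a valid config object here; whether QuantizationFactory wraps this class is not shown in the diff, so the standard transformers class is used directly.

from transformers import BitsAndBytesConfig
from kimchima import ModelFactory, PipelinesFactory, TokenizerFactory

# Standard transformers 8-bit config (requires the bitsandbytes package and a CUDA GPU).
# Assumption: kimchima forwards this object unchanged to from_pretrained.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# quantization_config is consumed by auto_model_for_causal_lm and forwarded to
# AutoModelForCausalLM.from_pretrained together with device_map='auto'.
model = ModelFactory.auto_model_for_causal_lm(
    pretrained_model_name_or_path="gpt2",
    quantization_config=bnb_config,
)
tokenizer = TokenizerFactory.auto_tokenizer(pretrained_model_name_or_path="gpt2")

# The pipeline no longer takes (or requires) a quantization_config of its own.
pipe = PipelinesFactory.text_generation(model=model, tokenizer=tokenizer)
pipe("Melbourne is the capital of Victoria")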