Skip to content

Commit

Permalink
Merge pull request #27 from SkywardAI/fix/quantization
Browse files Browse the repository at this point in the history
Remove quantization to support CPU-only execution
  • Loading branch information
Aisuko authored Apr 7, 2024
2 parents 6157c95 + e9ebfe5 commit a801295
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 9 deletions.
20 changes: 19 additions & 1 deletion examples/embeddings_examples.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from kimchima import (
ModelFactory,
TokenizerFactory,
StreamerFactory,
EmbeddingsFactory,
QuantizationFactory,

Check failure on line 6 in examples/embeddings_examples.py

View workflow job for this annotation

GitHub Actions / Code Quality 📦 (ubuntu-latest, 3.11, 1.8.2)

Ruff (F401)

examples/embeddings_examples.py:6:5: F401 `kimchima.QuantizationFactory` imported but unused
PipelinesFactory,
Devices
)

# get capability of GPU (Nvidia); on CPU-only machines this reports accordingly
capability = Devices.get_capability()
print(capability)


# streamer example: stream generated tokens to stdout as they are produced.
# NOTE(review): loads "gpt2" from the Hugging Face hub — requires network on first run.
model = ModelFactory.auto_model_for_causal_lm(pretrained_model_name_or_path="gpt2")
tokenizer = TokenizerFactory.auto_tokenizer(pretrained_model_name_or_path="gpt2")
streamer = StreamerFactory.text_streamer(tokenizer=tokenizer, skip_prompt=False, skip_prompt_tokens=False)


# build a text-generation pipeline around the model/tokenizer/streamer trio
pipe = PipelinesFactory.text_generation(
    model=model,
    tokenizer=tokenizer,
    text_streamer=streamer
)

pipe("Melbourne is the capital of Victoria")
4 changes: 0 additions & 4 deletions src/kimchima/pipelines/pipelines_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,13 @@ def text_generation(cls, *args,**kwargs)-> pipeline:
raise ValueError("tokenizer is required")
streamer=kwargs.pop("text_streamer", None)
max_new_tokens=kwargs.pop("max_new_tokens", 20)
quantization_config=kwargs.pop("quantization_config", None)
if quantization_config is None:
raise ValueError("quantization_config is required")

pipe=pipeline(
task="text-generation",
model=model,
tokenizer=tokenizer,
streamer=streamer,
max_new_tokens=max_new_tokens,
quantization_config=quantization_config,
device_map='auto',
**kwargs
)
Expand Down
23 changes: 19 additions & 4 deletions src/kimchima/pkg/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,47 @@ def __init__(self):
)

@classmethod
def auto_model(cls, *args, **kwargs) -> AutoModel:
    r"""
    Load a model via the Hugging Face Transformers ``AutoModel``.

    Args:
        pretrained_model_name_or_path: pretrained model name or path
            (required; passed via ``kwargs``).
        quantization_config: optional quantization configuration;
            ``None`` disables quantization (CPU-friendly default).

    Returns:
        The loaded ``AutoModel`` instance.

    Raises:
        ValueError: if ``pretrained_model_name_or_path`` is not provided.
    """
    pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", None)
    if pretrained_model_name_or_path is None:
        raise ValueError("pretrained_model_name_or_path cannot be None")

    quantization_config = kwargs.pop("quantization_config", None)
    # FIX: quantization_config must be passed as a keyword argument.
    # Passed positionally it lands in *model_args of from_pretrained and is
    # not applied as a quantization config (cf. auto_model_for_causal_lm,
    # which already uses the keyword form).
    model = AutoModel.from_pretrained(
        pretrained_model_name_or_path,
        quantization_config=quantization_config,
        **kwargs
    )
    logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
    return model

@classmethod
def auto_model_for_causal_lm(cls, *args, **kwargs) -> AutoModelForCausalLM:
    r"""
    Load a causal-LM model via the Hugging Face Transformers
    ``AutoModelForCausalLM``.

    Args:
        pretrained_model_name_or_path: pretrained model name or path
            (required; passed via ``kwargs``).
        quantization_config: optional quantization configuration;
            ``None`` disables quantization so the model runs on CPU.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance, with
        ``device_map='auto'`` so Accelerate places weights on the best
        available device(s).

    Raises:
        ValueError: if ``pretrained_model_name_or_path`` is not provided.
    """
    pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", None)
    if pretrained_model_name_or_path is None:
        raise ValueError("pretrained_model_name_or_path cannot be None")

    quantization_config = kwargs.pop("quantization_config", None)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        quantization_config=quantization_config,
        device_map='auto',
        **kwargs
    )
    logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
    return model

0 comments on commit a801295

Please sign in to comment.