-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from SkywardAI/feat/log
Redesign the methods by using classmethod
- Loading branch information
Showing
14 changed files
with
338 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .auto_cli import CommandAuto | ||
from .auto_cli import CommandAutoModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# coding=utf-8 | ||
# Copyright [2024] [SkywardAI] | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from kimchima.pkg import logging | ||
|
||
from .model_factory import ModelFactory | ||
from .tokenizer_factory import TokenizerFactory | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class EmbeddingsFactory: | ||
r""" | ||
Embeddings class to get embeddings from the specified model and tokenizer. | ||
The embeddings mean pooling is used to get the embeddings from the model, | ||
and the embeddings are normalized using L2 normalization. | ||
Args: | ||
pretrained_model_name_or_path: pretrained model name or path | ||
Returns: | ||
sentence_embeddings: sentence embeddings type torch.Tensor | ||
""" | ||
|
||
@classmethod | ||
def __init__(cls): | ||
raise EnvironmentError( | ||
"Embeddings is designed to be instantiated " | ||
"using the `Embeddings.from_pretrained(pretrained_model_name_or_path)` method." | ||
) | ||
|
||
|
||
@classmethod | ||
def auto_embeddings(cls, *args,**kwargs)-> torch.Tensor: | ||
r""" | ||
Get embeddings from the model. | ||
Args: | ||
prompt: prompt text | ||
device: device to run the model | ||
max_length: maximum length of the input text | ||
""" | ||
model=kwargs.pop('model', None) | ||
tokenizer=kwargs.pop('tokenizer', None) | ||
prompt = kwargs.pop('prompt', None) | ||
device = kwargs.pop('device', 'cpu') | ||
max_length = kwargs.pop('max_length', 512) | ||
|
||
|
||
inputs_ids = tokenizer(prompt, return_tensors='pt',max_length=max_length, padding=True, truncation=True).to(device) | ||
|
||
model=model.to(device) | ||
with torch.no_grad(): | ||
output = model(**inputs_ids) | ||
|
||
embeddings=cls.mean_pooling(model_output=output, attention_mask=inputs_ids['attention_mask']) | ||
logger.debug(f"Embedding mean pooling: {embeddings.shape}") | ||
|
||
# Normalize embeddings | ||
sentence_embeddings = F.normalize(embeddings, p=2, dim=1) | ||
|
||
return sentence_embeddings | ||
|
||
|
||
@classmethod | ||
#Mean Pooling - Take attention mask into account for correct averaging | ||
def mean_pooling(cls, **kwargs) -> torch.Tensor: | ||
r""" | ||
Mean Pooling - Take attention mask into account for correct averaging. | ||
Args: | ||
model_output: model output | ||
attention_mask: attention mask | ||
""" | ||
model_output = kwargs.get('model_output') | ||
attention_mask = kwargs.get('attention_mask') | ||
token_embeddings = model_output[0] #First element of model_output contains all token embeddings | ||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | ||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# coding=utf-8 | ||
# Copyright [2024] [SkywardAI] | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from transformers import AutoModel | ||
from kimchima.pkg import logging | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
class ModelFactory: | ||
r""" | ||
ModelFactory class to get the model from the specified model. | ||
Args: | ||
pretrained_model_name_or_path: pretrained model name or path | ||
""" | ||
def __init__(self): | ||
raise EnvironmentError( | ||
"ModelFactory is designed to be instantiated " | ||
"using the `ModelFactory.from_pretrained(pretrained_model_name_or_path)` method." | ||
) | ||
|
||
@classmethod | ||
def auto_model(cls, pretrained_model_name_or_path, **kwargs)-> AutoModel: | ||
r""" | ||
It is used to get the model from the Hugging Face Transformers AutoModel. | ||
Args: | ||
pretrained_model_name_or_path: pretrained model name or path | ||
""" | ||
if pretrained_model_name_or_path is None: | ||
raise ValueError("pretrained_model_name_or_path cannot be None") | ||
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) | ||
logger.debug(f"Loaded model: {pretrained_model_name_or_path}") | ||
return model | ||
|
Oops, something went wrong.