Commit

Update
mgoin committed Oct 13, 2023
1 parent 95181ac commit 7a4cc1d
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions src/deepsparse/sentence_transformers/sentence_transformer.py
@@ -25,12 +25,15 @@

 logger = logging.getLogger(__name__)

-DEFAULT_MODEL_NAME = "zeroshot/oneshot-minilm"
+DEFAULT_MODEL_NAME = "zeroshot/bge-small-en-v1.5-quant"


 class SentenceTransformer:
     def __init__(
-        self, model_name_or_path: str = DEFAULT_MODEL_NAME, export: bool = False
+        self,
+        model_name_or_path: str = DEFAULT_MODEL_NAME,
+        export: bool = False,
+        max_seq_length: int = 512,
     ):

         self.model_name_or_path = model_name_or_path
@@ -39,7 +42,7 @@ def __init__(
         )
         self.tokenizer = get_preprocessor(model_name_or_path)

-        self._max_seq_length = 512
+        self._max_seq_length = max_seq_length
         self._batch_size = 1

     def encode(
@@ -52,11 +55,9 @@ def encode(
         convert_to_tensor: bool = False,
         normalize_embeddings: bool = False,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+
+        # TODO: support executing with batch size > 1
         batch_size = 1
-        # if batch_size != self._batch_size:
-        #     self._batch_size = batch_size
-        #     self.model.reshape(input_shapes=f"[{self._batch_size},{self.get_max_seq_length()}]")
-        #     self.model.compile(batch_size=self._batch_size)

         if show_progress_bar is None:
             show_progress_bar = (
@@ -105,12 +106,14 @@ def encode(
                         last_mask_id -= 1

                     embeddings.append(token_emb[0 : last_mask_id + 1])
-            elif output_value is None:  # Return all outputs
+            elif output_value is None:
+                # Return all outputs
                 embeddings = []
                 for sent_idx in range(len(out_features["sentence_embedding"])):
                     row = {name: out_features[name][sent_idx] for name in out_features}
                     embeddings.append(row)
-            else:  # Sentence embeddings
+            else:
+                # Sentence embeddings
                 embeddings = out_features[output_value]
                 if normalize_embeddings:
                     embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
@@ -129,14 +132,14 @@ def encode(

         return all_embeddings

-    def get_max_seq_length(self):
+    def get_max_seq_length(self) -> int:
         """
         Returns the maximal sequence length for input the model accepts.
         Longer inputs will be truncated
         """
         return self._max_seq_length

-    def _text_length(self, text: Union[List[int], List[List[int]]]):
+    def _text_length(self, text: Union[List[int], List[List[int]]]) -> int:
         """
         Help function to get the length for the input text. Text can be either
         a list of ints (which means a single text as input), or a tuple of list of ints
@@ -157,10 +160,10 @@ def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
         Tokenizes the texts
         """
         return self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-        # return self.tokenizer(texts, padding='max_length', truncation=True,
-        #     max_length=self.get_max_seq_length(), return_tensors="pt")

-    def mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor):
+    def mean_pooling(
+        self, model_output: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
         """
         Compute mean pooling of token embeddings weighted by attention mask.
         Args:
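For context, a minimal usage sketch of the constructor after this change. It assumes the package exposes SentenceTransformer from deepsparse.sentence_transformers, as the file path suggests; the import path, the example sentence, and the 256 override are illustrative and not part of the commit.

from deepsparse.sentence_transformers import SentenceTransformer

# The default model is now the quantized BGE checkpoint, and max_seq_length
# (default 512) can be overridden at construction time.
model = SentenceTransformer(
    model_name_or_path="zeroshot/bge-small-en-v1.5-quant",
    max_seq_length=256,
)

embeddings = model.encode(["DeepSparse computes sentence embeddings on CPU."])
print(model.get_max_seq_length())  # 256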