From 7a4cc1d5d28579b6a3c9d29f050c51fb84e6cdef Mon Sep 17 00:00:00 2001
From: mgoin
Date: Fri, 13 Oct 2023 09:57:48 -0400
Subject: [PATCH] Update

---
 .../sentence_transformer.py | 31 ++++++++++---------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/deepsparse/sentence_transformers/sentence_transformer.py b/src/deepsparse/sentence_transformers/sentence_transformer.py
index 296079e725..99a88e78f4 100644
--- a/src/deepsparse/sentence_transformers/sentence_transformer.py
+++ b/src/deepsparse/sentence_transformers/sentence_transformer.py
@@ -25,12 +25,15 @@
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_MODEL_NAME = "zeroshot/oneshot-minilm"
+DEFAULT_MODEL_NAME = "zeroshot/bge-small-en-v1.5-quant"
 
 
 class SentenceTransformer:
     def __init__(
-        self, model_name_or_path: str = DEFAULT_MODEL_NAME, export: bool = False
+        self,
+        model_name_or_path: str = DEFAULT_MODEL_NAME,
+        export: bool = False,
+        max_seq_length: int = 512,
     ):
         self.model_name_or_path = model_name_or_path
 
@@ -39,7 +42,7 @@ def __init__(
         )
 
         self.tokenizer = get_preprocessor(model_name_or_path)
-        self._max_seq_length = 512
+        self._max_seq_length = max_seq_length
         self._batch_size = 1
 
     def encode(
@@ -52,11 +55,9 @@ def encode(
         convert_to_tensor: bool = False,
         normalize_embeddings: bool = False,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+
+        # TODO: support executing with batch size > 1
         batch_size = 1
-        # if batch_size != self._batch_size:
-        #     self._batch_size = batch_size
-        #     self.model.reshape(input_shapes=f"[{self._batch_size},{self.get_max_seq_length()}]")
-        #     self.model.compile(batch_size=self._batch_size)
 
         if show_progress_bar is None:
             show_progress_bar = (
@@ -105,12 +106,14 @@ def encode(
                         last_mask_id -= 1
 
                     embeddings.append(token_emb[0 : last_mask_id + 1])
-            elif output_value is None:  # Return all outputs
+            elif output_value is None:
+                # Return all outputs
                 embeddings = []
                 for sent_idx in range(len(out_features["sentence_embedding"])):
                     row = {name: out_features[name][sent_idx] for name in out_features}
                     embeddings.append(row)
-            else:  # Sentence embeddings
+            else:
+                # Sentence embeddings
                 embeddings = out_features[output_value]
                 if normalize_embeddings:
                     embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
@@ -129,14 +132,14 @@ def encode(
 
         return all_embeddings
 
-    def get_max_seq_length(self):
+    def get_max_seq_length(self) -> int:
         """
         Returns the maximal sequence length for input the model accepts.
         Longer inputs will be truncated
         """
         return self._max_seq_length
 
-    def _text_length(self, text: Union[List[int], List[List[int]]]):
+    def _text_length(self, text: Union[List[int], List[List[int]]]) -> int:
         """
         Help function to get the length for the input text. Text can be either
         a list of ints (which means a single text as input), or a tuple of list of ints
@@ -157,10 +160,10 @@ def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
         Tokenizes the texts
         """
         return self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-        # return self.tokenizer(texts, padding='max_length', truncation=True,
-        #     max_length=self.get_max_seq_length(), return_tensors="pt")
 
-    def mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor):
+    def mean_pooling(
+        self, model_output: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
         """
         Compute mean pooling of token embeddings weighted by attention mask.
 
         Args:
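
For context, a minimal usage sketch of the constructor change in this patch. It assumes the class is importable as deepsparse.sentence_transformers.SentenceTransformer and that encode() accepts a list of strings as its first argument (neither is shown in the diff); the sentences and the max_seq_length value are illustrative only.

    from deepsparse.sentence_transformers import SentenceTransformer

    # max_seq_length is now configurable at construction time (added in this patch);
    # 256 is an arbitrary example value, the default remains 512.
    model = SentenceTransformer(
        "zeroshot/bge-small-en-v1.5-quant",  # new default model set by this patch
        max_seq_length=256,
    )

    # encode() currently runs with an internal batch size of 1 (see the TODO added above).
    embeddings = model.encode(["DeepSparse runs transformer inference on CPUs."])
    print(len(embeddings))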