Support for SentenceTransformers with `deepsparse.sentence_transformers.SentenceTransformer` (#1301)

* Support for SentenceTransformer with `deepsparse.sentence_transformers.SentenceTransformer`
* Format
* Update install
* Update
* Address comments
* Add README
* Fix docs
* Update setup.py
* Update README
* Add batching
Showing 4 changed files with 341 additions and 2 deletions.
@@ -0,0 +1,86 @@
# DeepSparse SentenceTransformers

```python
from deepsparse.sentence_transformers import SentenceTransformer
```

[DeepSparse](https://github.com/neuralmagic/deepsparse) enhances [SentenceTransformers](https://www.sbert.net/), enabling more efficient computation of embeddings for text and images across numerous languages. The speedup comes from DeepSparse's sparse inference runtime, which accelerates these workloads on CPUs. The integration builds on the original PyTorch and Transformers stack and remains compatible with its pre-trained models. It is particularly well suited to identifying similar meanings in text, supporting applications such as semantic search and paraphrase detection.

## Installation

You can install the DeepSparse SentenceTransformers extension using pip:

```bash
pip install -U deepsparse-nightly[sentence_transformers]
```

## Usage

Using DeepSparse SentenceTransformers is straightforward and similar to the original:

```python
from deepsparse.sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', export=True)

# The sentences we would like to encode
sentences = ['This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of strings.',
    'The quick brown fox jumps over the lazy dog.']

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

# Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding.shape)
    print("")
```
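
Since the embeddings are plain numpy vectors by default, comparing them for semantic search or paraphrase detection only needs a similarity function. The snippet below is a minimal sketch that reuses the `sentences` and `embeddings` from the example above; the cosine-similarity helper is illustrative and not part of the `deepsparse.sentence_transformers` API.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two 1-D embedding vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Higher scores indicate more similar meanings
print(cosine_similarity(embeddings[0], embeddings[1]))
print(cosine_similarity(embeddings[0], embeddings[2]))
```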

## Accuracy Validation with MTEB

DeepSparse's efficiency doesn't come at the cost of accuracy, which can be verified with the Massive Text Embedding Benchmark (MTEB). Running MTEB validates the model's performance on standard tasks and confirms its reliability.

To get started, install MTEB along with the DeepSparse and SentenceTransformers libraries using the following command:

```bash
pip install mteb deepsparse-nightly[sentence_transformers] sentence-transformers
```

Once installed, you can leverage MTEB for an evaluation as shown in the Python script below:

```python
from mteb import MTEB

# Specify the model to use
model_name = "TaylorAI/bge-micro-v2"

# DeepSparse Model Evaluation
from deepsparse.sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, export=True)
evaluation = MTEB(tasks=["Banking77Classification"])
results_ds = evaluation.run(model, output_folder=f"results/ds-{model_name}")
print(results_ds)

# Original SentenceTransformers Model Evaluation
import sentence_transformers
model = sentence_transformers.SentenceTransformer(model_name)
evaluation = MTEB(tasks=["Banking77Classification"])
results_st = evaluation.run(model, output_folder=f"results/st-{model_name}")
print(results_st)
```

Output:
```
{'Banking77Classification': {'mteb_version': '1.1.1', 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300', 'mteb_dataset_name': 'Banking77Classification', 'test': {'accuracy': 0.8117207792207791, 'f1': 0.8109893836310513, 'accuracy_stderr': 0.007164150669501205, 'f1_stderr': 0.007346045502756079, 'main_score': 0.8117207792207791, 'evaluation_time': 8.05}}}
{'Banking77Classification': {'mteb_version': '1.1.1', 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300', 'mteb_dataset_name': 'Banking77Classification', 'test': {'accuracy': 0.8117207792207791, 'f1': 0.8109893836310513, 'accuracy_stderr': 0.007164150669501205, 'f1_stderr': 0.007346045502756079, 'main_score': 0.8117207792207791, 'evaluation_time': 12.21}}}
```

This script runs a comparative evaluation of the DeepSparse-optimized model and the original SentenceTransformers model on MTEB's "Banking77Classification" task, saving the results to separate directories for a clear, side-by-side comparison. In the output above, both models reach identical accuracy (main score 0.8117), while the DeepSparse run finishes in 8.05 seconds versus 12.21 seconds for the original, showing that the DeepSparse optimizations preserve accuracy.

---

This documentation is based on the original README from [SentenceTransformers](https://www.sbert.net/). It extends the original functionality with the optimizations provided by [DeepSparse](https://github.com/neuralmagic/deepsparse).

**Note**: The example usage is designed for the DeepSparse-enhanced version of SentenceTransformers. Make sure to follow the specific installation instructions for full compatibility. Performance optimizations with batching and other advanced features will be part of future updates.
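
The `encode()` method added in this commit (see `sentence_transformer.py` below) already accepts a `batch_size` argument. A minimal sketch of passing it, using the same model as in the usage example; the batch size shown is illustrative:

```python
from deepsparse.sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', export=True)

# batch_size=16 is illustrative; encode() defaults to one sentence per batch
embeddings = model.encode(['first document', 'second document', 'third document'],
                          batch_size=16)
print(embeddings.shape)
```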

@@ -0,0 +1,36 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helpers for running SentenceTransformer based models with DeepSparse and integrating with
huggingface/transformers
"""

# flake8: noqa

from deepsparse.analytics import deepsparse_analytics as _analytics


_analytics.send_event("python__sentence_transformers__init")


try:
    import optimum.deepsparse
except ImportError:
    raise ImportError(
        "Please install deepsparse[sentence_transformers] to use this pathway"
    )


from .sentence_transformer import SentenceTransformer

src/deepsparse/sentence_transformers/sentence_transformer.py
216 changes: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List, Tuple, Union

import numpy as np
from tqdm.autonotebook import trange
from transformers.onnx.utils import get_preprocessor

import torch
from optimum.deepsparse import DeepSparseModelForFeatureExtraction


logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "zeroshot/bge-small-en-v1.5-quant"


class SentenceTransformer:
    """
    Loads or creates a SentenceTransformer-compatible model that can be used to map
    text to embeddings.

    :param model_name_or_path: If it is a filepath on disk, it loads the model from
        that path. If it is not a path, it first tries to download and export a model
        from a HuggingFace models repository with that name.
    :param export: To load a PyTorch checkpoint and convert it to the DeepSparse
        format on-the-fly, you can set `export=True` when loading your model.
    :param max_seq_length: Sets a limit on the maximum sequence length allowed;
        this should be set to 512 for most models. Any text that exceeds this
        token length will be truncated.
    :param use_auth_token: HuggingFace authentication token to download private models.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL_NAME,
        export: bool = False,
        max_seq_length: int = 512,
        use_auth_token: Union[bool, str, None] = None,
    ):

        self.model_name_or_path = model_name_or_path
        self.model = DeepSparseModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=export, use_auth_token=use_auth_token
        )
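        # batch_size=0 is assumed here to request a dynamic batch size from the
        # engine, so that encode() below can run batches of varying sizes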
        self.model.compile(batch_size=0)
        self.tokenizer = get_preprocessor(model_name_or_path)

        self._max_seq_length = max_seq_length

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 1,
        show_progress_bar: bool = None,
        output_value: str = "sentence_embedding",
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        normalize_embeddings: bool = False,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
        Computes sentence embeddings

        :param sentences: the sentences to embed
        :param batch_size: the batch size used for the computation
        :param show_progress_bar: Output a progress bar when encoding sentences
        :param output_value: Default is sentence_embedding, to get sentence embeddings.
            Can be set to token_embeddings to get wordpiece token embeddings. Set to
            None to get all output values.
        :param convert_to_numpy: If true, the output is a list of numpy vectors.
            Else, it is a list of PyTorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return.
            Overwrites any setting from convert_to_numpy.
        :param normalize_embeddings: If set to true, returned vectors will have
            length 1. In that case, the faster dot-product (util.dot_score)
            instead of cosine similarity can be used.
        :return:
            By default, a list of tensors is returned. If convert_to_tensor,
            a stacked tensor is returned. If convert_to_numpy, a numpy matrix
            is returned.
        """

        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (
                logging.INFO,
                logging.DEBUG,
            )

        if convert_to_tensor:
            convert_to_numpy = False

        if output_value != "sentence_embedding":
            convert_to_tensor = False
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(
            sentences, "__len__"
        ):  # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        all_embeddings = []
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in trange(
            0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
        ):
            sentences_batch = sentences_sorted[start_index : start_index + batch_size]

            model_inputs = self.tokenize(sentences_batch)
            model_output = self.model(**model_inputs)

            out_features = {}
            # Keep the raw token embeddings and attention mask so the
            # token_embeddings / None output paths below can access them
            out_features["token_embeddings"] = model_output[0]
            out_features["attention_mask"] = model_inputs["attention_mask"]
            out_features["sentence_embedding"] = self.mean_pooling(
                model_output, model_inputs["attention_mask"]
            )

            embeddings = []
            if output_value == "token_embeddings":
                for token_emb, attention in zip(
                    out_features[output_value], out_features["attention_mask"]
                ):
                    # Apply the attention mask to remove embeddings for padding tokens
                    # Count non-zero values in the attention mask
                    actual_tokens_count = attention.sum().item()
                    # Slice the embeddings using this count
                    embeddings.append(token_emb[:actual_tokens_count])
            elif output_value is None:
                # Return all outputs
                for sent_idx in range(len(out_features["sentence_embedding"])):
                    row = {name: out_features[name][sent_idx] for name in out_features}
                    embeddings.append(row)
            else:
                # Sentence embeddings
                embeddings = out_features[output_value]
                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            all_embeddings.extend(embeddings)

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings

    def get_max_seq_length(self) -> int:
        """
        Returns the maximal sequence length the model accepts as input.
        Longer inputs will be truncated.
        """
        return self._max_seq_length

    def _text_length(self, text: Union[List[int], List[List[int]]]) -> int:
        """
        Helper function to get the length of the input text. Text can be either
        a list of ints (which means a single text as input), or a tuple of list of ints
        (representing several text inputs to the model).
        """

        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, "__len__"):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of length of individual strings
    def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
        """
        Tokenizes the texts
        """
        return self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    def mean_pooling(
        self, model_output: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute mean pooling of token embeddings weighted by attention mask.

        Args:
            model_output (torch.Tensor): The model's output tensor.
            attention_mask (torch.Tensor): The attention mask tensor.

        Returns:
            torch.Tensor: Mean-pooled embeddings.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )