From 869af579a8c9d189013589c98582c1c0bcfa0820 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 19 Oct 2023 10:37:06 -0600
Subject: [PATCH] Support for SentenceTransformers with
 `deepsparse.sentence_transformers.SentenceTransformer` (#1301)

* Support for SentenceTransformer with `deepsparse.sentence_transformers.SentenceTransformer`

* Format

* Update install

* Update

* Address comments

* Add README

* Fix docs

* Update setup.py

* Update README

* Add batching
---
 setup.py                                  |   5 +-
 .../sentence_transformers/README.md       |  86 +++++++
 .../sentence_transformers/__init__.py     |  36 +++
 .../sentence_transformer.py               | 216 ++++++++++++++++++
 4 files changed, 341 insertions(+), 2 deletions(-)
 create mode 100644 src/deepsparse/sentence_transformers/README.md
 create mode 100644 src/deepsparse/sentence_transformers/__init__.py
 create mode 100644 src/deepsparse/sentence_transformers/sentence_transformer.py

diff --git a/setup.py b/setup.py
index 3ebc986e40..1d9404ce9d 100644
--- a/setup.py
+++ b/setup.py
@@ -129,6 +129,7 @@ def _parse_requirements_file(file_path):
 _onnxruntime_deps = [
     "onnxruntime>=1.7.0",
 ]
+_torch_deps = ["torch>=1.7.0,<=2.0"]
 _image_classification_deps = [
     "torchvision>=0.3.0,<0.14",
     "opencv-python<=4.6.0.66",
 ]
@@ -150,6 +151,7 @@ def _parse_requirements_file(file_path):
     "scikit-learn",
     "seqeval",
 ]
+_sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps
 
 # haystack dependencies are installed from a requirements file to avoid
 # conflicting versions with NM's deepsparse/transformers
@@ -168,8 +170,6 @@ def _parse_requirements_file(file_path):
     "transformers<4.35",
 ]
 
-_torch_deps = ["torch>=1.7.0,<=2.0"]
-
 
 def _check_supported_system():
     if sys.platform.startswith("linux"):
@@ -275,6 +275,7 @@ def _setup_extras() -> Dict:
         "yolov8": _yolov8_integration_deps,
         "transformers": _transformers_integration_deps,
         "llm": _transformers_integration_deps,
+        "sentence_transformers": _sentence_transformers_integration_deps,
         "torch": _torch_deps,
         "clip": _clip_deps,
     }
diff --git a/src/deepsparse/sentence_transformers/README.md b/src/deepsparse/sentence_transformers/README.md
new file mode 100644
index 0000000000..39c843f553
--- /dev/null
+++ b/src/deepsparse/sentence_transformers/README.md
@@ -0,0 +1,86 @@
# DeepSparse SentenceTransformers

```python
from deepsparse.sentence_transformers import SentenceTransformer
```

[DeepSparse](https://github.com/neuralmagic/deepsparse) enhances [SentenceTransformers](https://www.sbert.net/), enabling more efficient computation of embeddings for text and images across numerous languages. The speedup comes from DeepSparse's sparse inference runtime, which accelerates these models on CPUs. The integration keeps the familiar PyTorch- and Transformers-based SentenceTransformers interface while extending it with DeepSparse's pre-trained models, and it is well suited to tasks such as semantic search and paraphrase detection.

## Installation

You can install the DeepSparse SentenceTransformers extension using pip:

```bash
pip install -U deepsparse-nightly[sentence_transformers]
```

## Usage

Using DeepSparse SentenceTransformers is straightforward and mirrors the original library:

```python
from deepsparse.sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', export=True)

# The sentences we would like to encode
sentences = ['This framework generates embeddings for each input sentence',
             'Sentences are passed as a list of strings.',
             'The quick brown fox jumps over the lazy dog.']

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

# Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding.shape)
    print("")
```

## Accuracy Validation with MTEB

DeepSparse's efficiency doesn't compromise accuracy, as verified with the Massive Text Embedding Benchmark (MTEB). Running MTEB tasks validates the model's performance against standard benchmarks, ensuring its reliability.

To get started, install MTEB along with the necessary DeepSparse and SentenceTransformers libraries:

```bash
pip install mteb deepsparse-nightly[sentence_transformers] sentence-transformers
```

Once installed, you can run an MTEB evaluation as shown in the Python script below:

```python
from mteb import MTEB

# Specify the model to use
model_name = "TaylorAI/bge-micro-v2"

# DeepSparse Model Evaluation
from deepsparse.sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, export=True)
evaluation = MTEB(tasks=["Banking77Classification"])
results_ds = evaluation.run(model, output_folder=f"results/ds-{model_name}")
print(results_ds)

# Original SentenceTransformers Model Evaluation
import sentence_transformers
model = sentence_transformers.SentenceTransformer(model_name)
evaluation = MTEB(tasks=["Banking77Classification"])
results_st = evaluation.run(model, output_folder=f"results/st-{model_name}")
print(results_st)
```

Output:
```
{'Banking77Classification': {'mteb_version': '1.1.1', 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300', 'mteb_dataset_name': 'Banking77Classification', 'test': {'accuracy': 0.8117207792207791, 'f1': 0.8109893836310513, 'accuracy_stderr': 0.007164150669501205, 'f1_stderr': 0.007346045502756079, 'main_score': 0.8117207792207791, 'evaluation_time': 8.05}}}
{'Banking77Classification': {'mteb_version': '1.1.1', 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300', 'mteb_dataset_name': 'Banking77Classification', 'test': {'accuracy': 0.8117207792207791, 'f1': 0.8109893836310513, 'accuracy_stderr': 0.007164150669501205, 'f1_stderr': 0.007346045502756079, 'main_score': 0.8117207792207791, 'evaluation_time': 12.21}}}
```

This script runs a side-by-side comparison of the DeepSparse-optimized model and the original SentenceTransformers model on MTEB's "Banking77Classification" task, saving the results in separate directories. In the run above, both models reach identical accuracy (81.17%), while the DeepSparse model finishes the evaluation in 8.05 seconds versus 12.21 seconds for the original.
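
## Example: Semantic Similarity

Beyond benchmark parity, a quick way to check embedding quality is to score sentence pairs directly. The sketch below is illustrative (the model and sentences are arbitrary choices): it reuses the `encode` API from the Usage section with `normalize_embeddings=True`, so a plain dot product equals cosine similarity.

```python
import numpy as np
from deepsparse.sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', export=True)

query = "How do I reset my bank password?"
candidates = [
    "Steps to change your online banking password",
    "The quick brown fox jumps over the lazy dog.",
]

# Unit-length vectors make the dot product equal to cosine similarity
embeddings = model.encode([query] + candidates, normalize_embeddings=True)
scores = np.dot(embeddings[1:], embeddings[0])

for candidate, score in zip(candidates, scores):
    print(f"{score:.3f}  {candidate}")
```

The first candidate should score markedly higher than the unrelated second one, confirming that the exported model preserves semantic structure.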

---

This documentation is based on the original README from [SentenceTransformers](https://www.sbert.net/). It extends the original functionality with the optimizations provided by [DeepSparse](https://github.com/neuralmagic/deepsparse).

**Note**: The example usage is designed for the DeepSparse-enhanced version of SentenceTransformers. Make sure to follow the specific installation instructions for full compatibility. Batched encoding is available through the `batch_size` argument of `encode()`; other advanced performance features will arrive in future updates.
diff --git a/src/deepsparse/sentence_transformers/__init__.py b/src/deepsparse/sentence_transformers/__init__.py
new file mode 100644
index 0000000000..338a9f5dba
--- /dev/null
+++ b/src/deepsparse/sentence_transformers/__init__.py
@@ -0,0 +1,36 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helpers for running SentenceTransformer-based models with DeepSparse and
integrating with huggingface/transformers
"""

# flake8: noqa

from deepsparse.analytics import deepsparse_analytics as _analytics


_analytics.send_event("python__sentence_transformers__init")


try:
    import optimum.deepsparse
except ImportError as import_error:
    raise ImportError(
        "Please install deepsparse[sentence_transformers] to use this pathway"
    ) from import_error


from .sentence_transformer import SentenceTransformer
diff --git a/src/deepsparse/sentence_transformers/sentence_transformer.py b/src/deepsparse/sentence_transformers/sentence_transformer.py
new file mode 100644
index 0000000000..a12c27d9e3
--- /dev/null
+++ b/src/deepsparse/sentence_transformers/sentence_transformer.py
@@ -0,0 +1,216 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from tqdm.autonotebook import trange
from transformers.onnx.utils import get_preprocessor

import torch
from optimum.deepsparse import DeepSparseModelForFeatureExtraction


logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "zeroshot/bge-small-en-v1.5-quant"


class SentenceTransformer:
    """
    Loads or creates a SentenceTransformer-compatible model that can be used to map
    text to embeddings.

    :param model_name_or_path: If it is a filepath on disk, the model is loaded
        from that path. If it is not a path, it first tries to download and
        export a model with that name from the Hugging Face model hub.
    :param export: Set `export=True` to load a PyTorch checkpoint and convert
        it to the DeepSparse format on the fly.
    :param max_seq_length: The maximum sequence length allowed; 512 is
        appropriate for most models. Any text that exceeds this token length
        will be truncated.
    :param use_auth_token: Hugging Face authentication token for downloading
        private models.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL_NAME,
        export: bool = False,
        max_seq_length: int = 512,
        use_auth_token: Union[bool, str, None] = None,
    ):

        self.model_name_or_path = model_name_or_path
        self.model = DeepSparseModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=export, use_auth_token=use_auth_token
        )
        # batch_size=0 compiles the engine for dynamic batch sizes, so encode()
        # can submit batches of any size (including a smaller final batch)
        self.model.compile(batch_size=0)
        self.tokenizer = get_preprocessor(model_name_or_path)

        self._max_seq_length = max_seq_length

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 1,
        show_progress_bar: Optional[bool] = None,
        output_value: str = "sentence_embedding",
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        normalize_embeddings: bool = False,
    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
        """
        Computes sentence embeddings

        :param sentences: the sentences to embed
        :param batch_size: the batch size used for the computation
        :param show_progress_bar: Whether to show a progress bar while encoding
            sentences
        :param output_value: Defaults to "sentence_embedding" to get sentence
            embeddings. Can be set to "token_embeddings" to get wordpiece token
            embeddings, or to None to get all output values
        :param convert_to_numpy: If true, the output is a list of numpy vectors.
            Else, it is a list of PyTorch tensors.
        :param convert_to_tensor: If true, one large stacked tensor is returned.
            Overrides any setting of convert_to_numpy
        :param normalize_embeddings: If set to true, returned vectors will have
            length 1. In that case, the faster dot-product (util.dot_score)
            can be used instead of cosine similarity.

        :return:
            By default, a list of tensors is returned. If convert_to_tensor,
            a stacked tensor is returned. If convert_to_numpy, a numpy matrix
            is returned.
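
        Example (an illustrative sketch; uses the default quantized model,
        which is downloaded on first use)::

            from deepsparse.sentence_transformers import SentenceTransformer

            model = SentenceTransformer()
            embeddings = model.encode(
                ["DeepSparse runs on CPUs", "Sparsity speeds up inference"],
                batch_size=2,
                normalize_embeddings=True,
            )
            print(embeddings.shape)  # (2, embedding_dim) numpy array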
+ """ + + if show_progress_bar is None: + show_progress_bar = logger.getEffectiveLevel() in ( + logging.INFO, + logging.DEBUG, + ) + + if convert_to_tensor: + convert_to_numpy = False + + if output_value != "sentence_embedding": + convert_to_tensor = False + convert_to_numpy = False + + input_was_string = False + if isinstance(sentences, str) or not hasattr( + sentences, "__len__" + ): # Cast an individual sentence to a list with length 1 + sentences = [sentences] + input_was_string = True + + all_embeddings = [] + length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + + for start_index in trange( + 0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar + ): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + + model_inputs = self.tokenize(sentences_batch) + model_output = self.model(**model_inputs) + + out_features = {} + out_features["sentence_embedding"] = self.mean_pooling( + model_output, model_inputs["attention_mask"] + ) + + embeddings = [] + if output_value == "token_embeddings": + for token_emb, attention in zip( + out_features[output_value], out_features["attention_mask"] + ): + # Apply the attention mask to remove embeddings for padding tokens + # Count non-zero values in the attention mask + actual_tokens_count = attention.sum().item() + # Slice the embeddings using this count + embeddings.append(token_emb[:actual_tokens_count]) + elif output_value is None: + # Return all outputs + for sent_idx in range(len(out_features["sentence_embedding"])): + row = {name: out_features[name][sent_idx] for name in out_features} + embeddings.append(row) + else: + # Sentence embeddings + embeddings = out_features[output_value] + if normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + all_embeddings.extend(embeddings) + + all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] + + if convert_to_tensor: + all_embeddings = torch.stack(all_embeddings) + elif convert_to_numpy: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + + if input_was_string: + all_embeddings = all_embeddings[0] + + return all_embeddings + + def get_max_seq_length(self) -> int: + """ + Returns the maximal sequence length for input the model accepts. + Longer inputs will be truncated + """ + return self._max_seq_length + + def _text_length(self, text: Union[List[int], List[List[int]]]) -> int: + """ + Help function to get the length for the input text. Text can be either + a list of ints (which means a single text as input), or a tuple of list of ints + (representing several text inputs to the model). + """ + + if isinstance(text, dict): # {key: value} case + return len(next(iter(text.values()))) + elif not hasattr(text, "__len__"): # Object has no len() method + return 1 + elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints + return len(text) + else: + return sum([len(t) for t in text]) # Sum of length of individual strings + + def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]): + """ + Tokenizes the texts + """ + return self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt") + + def mean_pooling( + self, model_output: torch.Tensor, attention_mask: torch.Tensor + ) -> torch.Tensor: + """ + Compute mean pooling of token embeddings weighted by attention mask. + Args: + model_output (torch.Tensor): The model's output tensor. 

    def mean_pooling(
        self, model_output: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute mean pooling of token embeddings weighted by attention mask.

        :param model_output: The model's output; its first element holds the
            token embeddings.
        :param attention_mask: The attention mask tensor.
        :return: Mean-pooled sentence embeddings.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        # Broadcast the mask over the hidden dimension, then average only the
        # unmasked token positions (the clamp avoids division by zero)
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
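

# Minimal sketch (illustrative, not part of the library API): sanity-check the
# mean-pooling math above on a toy batch without loading a model.
if __name__ == "__main__":
    # Two "sentences", three tokens each, hidden size four; the second
    # sentence has its last token masked out as padding.
    toy_embeddings = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    toy_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

    mask = toy_mask.unsqueeze(-1).expand(toy_embeddings.size()).float()
    pooled = torch.sum(toy_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
    # Row 0 averages all three tokens; row 1 averages only the first two.
    print(pooled)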