
Commit

Address comments
mgoin committed Oct 16, 2023
1 parent ab2f0b8 commit 5936830
Showing 1 changed file with 49 additions and 12 deletions.
src/deepsparse/sentence_transformers/sentence_transformer.py (61 changes: 49 additions & 12 deletions)
@@ -29,16 +29,32 @@


class SentenceTransformer:
"""
Loads or creates a SentenceTransformer-compatible model that can be used to map
text to embeddings.
:param model_name_or_path: If it is a filepath on disc, it loads the model from
that path. If it is not a path, it first tries to download and export a model
from a HuggingFace models repository with that name.
:param export: To load a PyTorch checkpoint and convert it to the DeepSparse
format on-the-fly, you can set `export=True` when loading your model.
:param max_seq_length: Sets a limit on the maximum sequence length allowed;
this should be set to 512 for most models. Any text that exceeds this
token length will be truncated.
:param use_auth_token: HuggingFace authentication token to download private models.
"""

def __init__(
self,
model_name_or_path: str = DEFAULT_MODEL_NAME,
export: bool = False,
max_seq_length: int = 512,
use_auth_token: Union[bool, str, None] = None,
):

self.model_name_or_path = model_name_or_path
self.model = DeepSparseModelForFeatureExtraction.from_pretrained(
- model_name_or_path, export=export
+ model_name_or_path, export=export, use_auth_token=use_auth_token
)
self.tokenizer = get_preprocessor(model_name_or_path)
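
For orientation, a minimal usage sketch of the constructor updated in this commit; the model id and token below are made-up placeholders, and it is assumed the class is re-exported from deepsparse.sentence_transformers:

from deepsparse.sentence_transformers import SentenceTransformer

# Load a (hypothetical) private model from the Hugging Face Hub, exporting the
# PyTorch checkpoint to the DeepSparse format on the fly.
model = SentenceTransformer(
    "my-org/private-minilm",   # placeholder model id
    export=True,
    use_auth_token=True,       # or pass a token string instead of True
)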

@@ -55,14 +71,36 @@ def encode(
convert_to_tensor: bool = False,
normalize_embeddings: bool = False,
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings
:param sentences: the sentences to embed
:param batch_size: the batch size used for the computation
:param show_progress_bar: Output a progress bar when encoding sentences
:param output_value: Default is sentence_embedding, to get sentence
embeddings. Can be set to token_embeddings to get wordpiece token
embeddings. Set to None to get all output values
:param convert_to_numpy: If true, the output is a list of numpy vectors.
Otherwise, it is a list of PyTorch tensors.
:param convert_to_tensor: If true, a single stacked tensor is returned.
Overrides any setting from convert_to_numpy
:param normalize_embeddings: If set to true, returned vectors will have
length 1. In that case, the faster dot-product (util.dot_score) can be
used instead of cosine similarity.
:return:
By default, a list of tensors is returned. If convert_to_tensor,
a stacked tensor is returned. If convert_to_numpy, a numpy matrix
is returned.
"""

- # TODO: support executing with batch size > 1
+ # TODO: support faster execution with batch size > 1
batch_size = 1

if show_progress_bar is None:
- show_progress_bar = (
- logger.getEffectiveLevel() == logging.INFO
- or logger.getEffectiveLevel() == logging.DEBUG
+ show_progress_bar = logger.getEffectiveLevel() in (
+ logging.INFO,
+ logging.DEBUG,
)

if convert_to_tensor:
@@ -96,19 +134,18 @@ def encode(
model_output, model_inputs["attention_mask"]
)

+ embeddings = []
if output_value == "token_embeddings":
- embeddings = []
for token_emb, attention in zip(
out_features[output_value], out_features["attention_mask"]
):
- last_mask_id = len(attention) - 1
- while last_mask_id > 0 and attention[last_mask_id].item() == 0:
- last_mask_id -= 1

- embeddings.append(token_emb[0 : last_mask_id + 1])
+ # Apply the attention mask to remove embeddings for padding tokens
+ # Count non-zero values in the attention mask
+ actual_tokens_count = attention.sum().item()
+ # Slice the embeddings using this count
+ embeddings.append(token_emb[:actual_tokens_count])
elif output_value is None:
# Return all outputs
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {name: out_features[name][sent_idx] for name in out_features}
embeddings.append(row)
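As a standalone illustration of the slicing change in the last hunk, here is a small sketch (made-up tensors, not part of this diff) showing that counting non-zero attention-mask values selects the same token embeddings as the old backwards scan whenever padding is contiguous on the right:

import torch

# Hypothetical example: 5 token positions, of which the last 2 are padding.
token_emb = torch.randn(5, 8)
attention = torch.tensor([1, 1, 1, 0, 0])

# Old approach: walk backwards to find the last non-padding position.
last_mask_id = len(attention) - 1
while last_mask_id > 0 and attention[last_mask_id].item() == 0:
    last_mask_id -= 1
old_slice = token_emb[0 : last_mask_id + 1]

# New approach: count non-zero mask values and slice directly.
actual_tokens_count = attention.sum().item()
new_slice = token_emb[:actual_tokens_count]

assert torch.equal(old_slice, new_slice)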

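Finally, a brief end-to-end sketch of encode with the options documented above; it assumes the class is importable from deepsparse.sentence_transformers and that convert_to_numpy defaults to True as in sentence-transformers:

from deepsparse.sentence_transformers import SentenceTransformer

model = SentenceTransformer()  # uses DEFAULT_MODEL_NAME from the signature above
sentences = [
    "DeepSparse runs transformer models efficiently on CPUs",
    "Sentence embeddings map text to dense vectors",
]
# normalize_embeddings=True returns unit-length vectors, so a plain dot
# product can serve as a similarity score.
embeddings = model.encode(sentences, normalize_embeddings=True)
print(len(embeddings), len(embeddings[0]))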