Refactor of perplexity computation #1197

Merged
98 commits merged on Nov 10, 2023
Changes shown are from 9 commits

Commits (98 total)
11ca2b5
Add input_tokes as optional output
anmarques Aug 22, 2023
530d625
Refactor Perplexity class to only compute perplexity. All other task-…
anmarques Aug 22, 2023
c816922
Simplify perplexity evaluation. Evaluation takes place as batch size …
anmarques Aug 22, 2023
5c89d89
Splits wikitext at regular intervals of the same length as the sequen…
anmarques Aug 22, 2023
5767ca0
Add argument for accumulation of negative log likelihood
anmarques Aug 22, 2023
ec2162e
Accumulate likelihood for wikitext
anmarques Aug 22, 2023
a7941ef
Simplification
anmarques Aug 22, 2023
3ddd45c
Add support for wikitext-style ppl evaluation
anmarques Aug 23, 2023
756169c
Compute batch instead of storing until compute method. This drastical…
anmarques Aug 23, 2023
97b5f1a
Remove torch dependency
anmarques Aug 23, 2023
91b5921
Move split of dataset into helper function
anmarques Aug 23, 2023
e7c9342
Merge branch 'main' into research/ppl_refactor
anmarques Aug 23, 2023
8ef20e7
Quality fixes
anmarques Aug 23, 2023
5a60228
Remove debugging prints
anmarques Aug 24, 2023
2559e41
Remove debugging prints
anmarques Aug 24, 2023
3b7e14b
Incorporate fixes for kv-cache
anmarques Aug 24, 2023
b5f845b
Include doc string for accumulate
anmarques Aug 25, 2023
6f3b246
Add support to trust-remote-code arguments
anmarques Aug 25, 2023
42f0da2
Merge branch 'main' into research/ppl_refactor
anmarques Aug 25, 2023
2056ec5
Add support to c4
anmarques Aug 25, 2023
8f15636
Merge branch 'main' into research/ppl_refactor
dbogunowicz Aug 28, 2023
858bee6
add a missing include_prompt_logits param
dbogunowicz Aug 28, 2023
4f6eb6b
Remove unnecessary capping at sequence length (it's incorrect for cac…
anmarques Aug 28, 2023
51370d4
Merge branch 'main' into research/ppl_refactor
anmarques Aug 28, 2023
b3d99d1
Merge branch 'main' into research/ppl_refactor
anmarques Aug 28, 2023
53d8a09
Merge branch 'main' into research/ppl_refactor
anmarques Aug 29, 2023
ab757d0
Simplify processing for concatenated datasets
anmarques Aug 29, 2023
bc0920a
Merge branch 'main' into research/ppl_refactor
anmarques Sep 1, 2023
f21eaf3
Fix kv cache update
anmarques Sep 1, 2023
2a18c45
Fix kv cache update
anmarques Sep 1, 2023
7e8da1c
Quality fixes
anmarques Sep 1, 2023
717f518
Merge branch 'main' into research/ppl_refactor
anmarques Sep 1, 2023
c8187f6
Merge branch 'main' into research/ppl_refactor
dbogunowicz Sep 6, 2023
1f9c358
remove batch size from pipeline instantiation
anmarques Sep 8, 2023
b2d2827
Merge branch 'main' into research/ppl_refactor
anmarques Sep 8, 2023
099b366
Rename to wikitext2
anmarques Sep 8, 2023
5455c7c
Remove trust_remote_code argument
anmarques Sep 8, 2023
6a330d4
Remove use_deepsparse_cache argument
anmarques Sep 8, 2023
e19eb83
Merge branch 'main' into research/ppl_refactor
bfineran Sep 11, 2023
a448667
Change padding of output to left in order to match padding of input i…
anmarques Sep 11, 2023
7b5d8e5
Merge remote-tracking branch 'origin/research/ppl_refactor' into rese…
anmarques Sep 11, 2023
54b560c
Allow trust_remote_code to be passed as argument (in some cases token…
anmarques Sep 11, 2023
ad35340
Move process_concatenated_datasets to helpers file
anmarques Sep 11, 2023
b16a5f6
Added support for max_text_length to speed up processing of long data…
anmarques Sep 13, 2023
065864a
Rebase w/ main
anmarques Sep 20, 2023
7583e28
Merge branch 'main' into research/ppl_refactor
anmarques Sep 20, 2023
59b93c5
Rebase w/ main
anmarques Sep 20, 2023
f4554b1
Fix typo
anmarques Sep 20, 2023
15c031a
Merge branch 'main' into research/ppl_refactor
anmarques Sep 22, 2023
1ba429c
Merge branch 'main' into research/ppl_refactor
anmarques Sep 25, 2023
5530895
Merge branch 'main' into research/ppl_refactor
anmarques Sep 26, 2023
c5bd383
Rebase
anmarques Sep 26, 2023
0672e0d
Merge branch 'main' into research/ppl_refactor
anmarques Sep 26, 2023
091aeca
Use max_length instead of max_new_tokens
anmarques Sep 26, 2023
d4d7a36
Merge branch 'main' into research/ppl_refactor
anmarques Sep 27, 2023
75417f2
Merge branch 'main' into research/ppl_refactor
anmarques Sep 28, 2023
8751886
Merge branch 'main' into research/ppl_refactor
anmarques Sep 29, 2023
cebad84
Merge branch 'main' into research/ppl_refactor
anmarques Oct 2, 2023
141c966
Merge branch 'main' into research/ppl_refactor
anmarques Oct 3, 2023
6bc08bc
Rebase
anmarques Oct 3, 2023
dc943d7
Added typing and docstring
anmarques Oct 3, 2023
8f3743a
Added typing and docstring
anmarques Oct 3, 2023
5e1d808
Define concantenated datasets
anmarques Oct 3, 2023
0785321
Add warning about batch-size not being a supported argument for some …
anmarques Oct 3, 2023
d8914f0
Add unit test for pipeline and generation in ppl eval
anmarques Oct 3, 2023
5bf076b
Add lifecycle in docstring
anmarques Oct 3, 2023
2e56b50
Merge branch 'main' into research/ppl_refactor
anmarques Oct 4, 2023
96c5794
Merge branch 'main' into research/ppl_refactor
anmarques Oct 10, 2023
e2ecca4
Merge branch 'main' into research/ppl_refactor
anmarques Oct 11, 2023
3e536c5
Merge branch 'main' into research/ppl_refactor
anmarques Oct 13, 2023
cb08231
Merge branch 'main' into research/ppl_refactor
anmarques Oct 20, 2023
ecf3b77
Add copyright
anmarques Oct 20, 2023
fe37c32
Style fixes
anmarques Oct 20, 2023
ddd0325
Quality fixes
anmarques Oct 20, 2023
24a91a3
Quality fixes
anmarques Oct 20, 2023
301115c
Quality fixes
anmarques Oct 20, 2023
e402da9
Quality fixes
anmarques Oct 20, 2023
b48e05f
Quality fixes
anmarques Oct 20, 2023
61b9c5c
Quality fixes
anmarques Oct 20, 2023
8c08e84
Merge branch 'main' into research/ppl_refactor
anmarques Oct 20, 2023
34ee8f6
Quality fixes
anmarques Oct 20, 2023
b032101
Quality fixes
anmarques Oct 20, 2023
f1de171
Merge branch 'main' into research/ppl_refactor
anmarques Oct 20, 2023
f3cbf3d
Quality fixes
anmarques Oct 20, 2023
483449e
Quality fixes
anmarques Oct 20, 2023
a329e25
Merge branch 'main' into research/ppl_refactor
anmarques Oct 23, 2023
fe4b267
Merge branch 'main' into research/ppl_refactor
anmarques Oct 23, 2023
c55a05e
Merge branch 'main' into research/ppl_refactor
anmarques Oct 24, 2023
5d46ddf
Merge branch 'main' into research/ppl_refactor
anmarques Oct 26, 2023
bf139c2
Merge branch 'main' into research/ppl_refactor
anmarques Oct 30, 2023
e6e7828
Rebase
anmarques Oct 30, 2023
d7c6e5a
Rebase
anmarques Oct 30, 2023
f8c64a1
Merge branch 'main' into research/ppl_refactor
rahul-tuli Nov 1, 2023
28919b1
Merge branch 'main' into research/ppl_refactor
anmarques Nov 8, 2023
21c6f0d
Re-add unit test
anmarques Nov 8, 2023
fa0cb4b
Style fix
anmarques Nov 8, 2023
bf1b0cf
Update unit test
anmarques Nov 8, 2023
0c618a6
Update unit test
anmarques Nov 8, 2023
src/deepsparse/transformers/eval_downstream.py (127 changes: 104 additions & 23 deletions)
@@ -71,16 +71,56 @@
from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline
from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1


from transformers import AutoTokenizer
from datasets import load_dataset, load_metric # isort: skip


def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"):
if args.max_samples:
batch_size = min(batch_size, args.max_samples)

dataset = load_dataset(dataset_name)["test"]
def perplexity_eval(args, dataset_name="openai_humaneval"):
if dataset_name == "wikitext":
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

# Dataset is split into sections that contain "max_sequence_length" tokens.
# To split the dataset, first tokenize text
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
raw_text = "\n\n".join(raw_dataset["text"])
input_tokens = tokenizer(
raw_text,
return_tensors="np",
)["input_ids"][0]

# Then split the tokenized text into sections of size "max_sequence_length" and
# decode each section back into text format
dataset = []
for i in range(len(input_tokens) // args.max_sequence_length):
start = i * args.max_sequence_length
end = (i+1) * args.max_sequence_length
dataset.append(
tokenizer.decode(
input_tokens[start:end],
clean_up_tokenization_spaces=False,
)
)

# Handle any leftover tokens
if (i+1) * args.max_sequence_length < len(input_tokens):
start = (i+1) * args.max_sequence_length
end = len(input_tokens)
dataset.append(
tokenizer.decode(
input_tokens[start:end],
clean_up_tokenization_spaces=False,
)
)

# Set perplexity computation to accumulate negative log-likelihood across
# sections
accumulate = True
else:
dataset = load_dataset(dataset_name, split="test")
accumulate = False

# We'll use the text generation pipeline to generate a single token.
# Along with the token, it returns the logits for input sequence
text_generation = Pipeline.create(
task="text-generation",
model_path=args.model_path,
@@ -90,22 +130,58 @@ def perplexity_eval(args, dataset_name="openai_humaneval"):
prompt_processing_sequence_length=args.max_sequence_length,
max_generated_tokens=1,
)
perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size)
active_engines = [
engine
for engine in [text_generation.engine, text_generation.multitoken_engine]
if engine
]
print("Engine info: ")
[print(f"{engine}\n") for engine in active_engines]
predictions = []

# Instantiate perplexity metric
perplexity_metrics = Perplexity(accumulate=accumulate)

# Loop through samples
batch_samples = []
run_inference = False
end_evaluation = False
dataset_length = len(dataset)
for idx, sample in _enumerate_progress(dataset, args.max_samples):
predictions.append(sample["prompt"] + sample["canonical_solution"])
if len(predictions) == batch_size:
perplexity_metrics.add_batch(predictions)
predictions = []
if args.max_samples and idx >= args.max_samples:

# Collect input sequence
if dataset_name == "openai_humaneval":
sample = sample["prompt"] + sample["canonical_solution"]
batch_samples.append(sample)

if args.max_samples and idx == args.max_samples - 1:
run_inference = True
end_evaluation = True

if (idx + 1) % args.batch_size == 0 or idx == dataset_length - 1:
run_inference = True

if run_inference:
# Perform single token generation
prediction = text_generation(
sequences=batch_samples,
return_logits=True,
return_input_tokens=True,
fixed_sequences_length=True,
)

# Handle one sample at a time to make it simpler for masking
for s in range(len(batch_samples)):
# Need to remove tokens that were masked
input_ids = prediction.input_tokens["input_ids"][s].flatten()
logits = prediction.logits[s]
attention_mask = prediction.input_tokens["attention_mask"][s].flatten()

logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
input_ids = numpy.compress(attention_mask, input_ids)[1:]

# Add predictions (logits) and targets (input_ids) to metric
perplexity_metrics.add_batch(logits, input_ids)

# Reset batch
batch_samples.clear()
run_inference = False

if end_evaluation:
break

return perplexity_metrics
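
The per-sample handling above hinges on a shift-by-one alignment: after padded positions are dropped via the attention mask, the logit at position t is scored against the token at position t + 1. Below is a minimal, self-contained sketch of that alignment (NumPy only; shapes and values are invented for illustration and are not part of the PR):

```python
import numpy

vocab_size = 8

# Toy stand-ins for one sample returned by the pipeline
logits = numpy.random.rand(6, vocab_size).astype(numpy.float32)  # (seq_len, vocab)
input_ids = numpy.array([3, 5, 2, 7, 0, 0])       # last two positions are padding
attention_mask = numpy.array([1, 1, 1, 1, 0, 0])  # 0 marks padded positions

# Drop padded positions from both arrays
logits = numpy.compress(attention_mask, logits, axis=0)
input_ids = numpy.compress(attention_mask, input_ids)

# Shift by one: the logit at position t predicts the token at position t + 1
predictions = logits[:-1, :]  # scores for the "next token" at each position
targets = input_ids[1:]       # the tokens those scores are evaluated against

assert predictions.shape[0] == targets.shape[0] == 3
```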


@@ -474,10 +550,10 @@ def _split_train_val(train_dataset, val_ratio, seed=42):
"imdb": imdb_eval,
"conll2003": conll2003_eval,
"go_emotions": go_emotions_eval,
"openai_humaneval": perplexity_eval,
"openai_humaneval": lambda args: perplexity_eval(args, dataset_name="openai_humaneval"),
"wikitext": lambda args: perplexity_eval(args, dataset_name="wikitext"),
}


def parse_args():
parser = argparse.ArgumentParser(
description="Evaluate a Hugging Face Transformers "
@@ -605,7 +681,12 @@ def parse_args():
type=bool,
default=False,
)

parser.add_argument(
"--batch-size",
help="Batch size to evaluate model. Default is 1",
type=int,
default=1,
)
return parser.parse_args()


src/deepsparse/transformers/metrics.py (208 changes: 82 additions & 126 deletions)
@@ -16,17 +16,11 @@
Utilities for evaluation metric computation
"""


from itertools import compress
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional

import numpy
from tqdm import tqdm

import torch
from deepsparse import Pipeline
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from deepsparse.transformers.utils.helpers import pad_to_fixed_length
from sklearn.metrics import precision_recall_fscore_support


@@ -37,134 +31,96 @@


class Perplexity:
def __init__(self, pipeline: Pipeline, batch_size: int = 16):
"""
Given the pipeline, compute the perplexity of the model
on the given text input.

Code adapted from:
https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501

:param pipeline: The pipeline to use for text generation
:param batch_size: The batch size to split the input text into
non-overlapping batches
def __init__(self, accumulate: bool = False):
"""
if not isinstance(pipeline, TextGenerationPipeline):
raise ValueError(
"Perplexity can only be computed for text generation pipelines"
)
self._pipeline = pipeline
self._batch_size = batch_size
self._sequence_length = pipeline.sequence_length
self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

self.perplexities = []

def add_batch(self, predictions: List[str]):
Class for computing perplexity.
"""
Run the model on the given input sequences and compute the perplexity.
The resulting perplexity is appended to the list of perplexities.
self._predictions = None
self._targets = None
self._accumulate = accumulate
if accumulate:
self._neg_log_likelihood = 0.
self._number_tokens = 0
else:
self._perplexities = None

:param predictions: The predictions to compute perplexity on
def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray):
"""
Computes perplexity or negative log-likelihood for each batch
(depending on accumulate argument)
and track results.

Tracks perplexity or negative log-likelihood since storing
predictions may require a lot of memory.

:param predictions: predicted scores.
Accepted shapes:
- [batch_size, sequence_length, vocab_size]
- [sequence_length, vocab_size] (batch size = 1)
Note: sequence length has to be uniform within a batch, but not all
batches require the same sequence length
:param targets: target values - index of correct vocabulary entry
"""
# tokenize the input text
encodings = self._pipeline.tokenizer(
predictions,
return_attention_mask=True,
max_length=self._sequence_length,
truncation=True,
padding="max_length",
)

encoded_texts = encodings["input_ids"]
attention_masks = encodings["attention_mask"]

for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)):
end_index = min(start_index + self._batch_size, len(encoded_texts))
encoded_batch = encoded_texts[start_index:end_index]
attention_mask = attention_masks[start_index:end_index]

# Computing the ground truth labels

# `encoded_batch` contains sequences of tokens padded
# with <PAD> tokens from the left side. We need to remove
# them and zero-pad from the right side up to the length
# of the longest sequence in the batch

encoded_batch = [
list(compress(sequence, attn_mask))
for (sequence, attn_mask) in zip(encoded_batch, attention_mask)
]
max_sequence_len = max([len(sequence) for sequence in encoded_batch])

encoded_batch = [
pad_to_fixed_length(numpy.array(sequence), max_sequence_len)
for sequence in encoded_batch
]
encoded_batch = numpy.stack(encoded_batch)

# We need to apply the analogous transformation to the attention mask
attention_mask = numpy.array(attention_mask)
attention_mask = [
list(filter(lambda num: num != 0, mask)) for mask in attention_mask
]
attention_mask = [
pad_to_fixed_length(numpy.array(mask), max_sequence_len)
for mask in attention_mask
]
attention_mask = numpy.stack(attention_mask)

labels = encoded_batch

out = self._pipeline(
sequences=predictions, return_logits=True, fixed_sequences_length=True
)

logits = out.logits

if not self._pipeline.cache_support_enabled:
# when running inference without cache, we need to apply
# analogous transformations to the logits as we did to the labels
# and attention mask

# remove "nonsensical" logits for <PAD> tokens
logits = [
logit[-attn_mask.sum() :, :]
for (logit, attn_mask) in zip(logits, attention_mask)
]
# pad logits to max length
logits = [
pad_to_fixed_length(logit, max_sequence_len) for logit in logits
]
logits = numpy.stack(logits)

# shift logits and labels create the input and target for the loss function
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]
shift_attention_mask_batch = attention_mask[:, 1:]

# compute perplexity for this batch
perplexity_batch = torch.exp(
(
self._loss_fct(
torch.tensor(shift_logits.transpose(0, 2, 1)),
torch.tensor(shift_labels),
)
* torch.tensor(shift_attention_mask_batch)
).sum(1)
/ torch.tensor(shift_attention_mask_batch).sum(1)
)
self.perplexities.extend(perplexity_batch.numpy().tolist())
if self._accumulate:
# If accumulate is True, every token from the batch contributes equally to the
# negative log-likelihood.
# Thus, merge batch and sequence length dimensions and compute negative
# log-likelihood for all tokens, and accumulate to total
predictions = numpy.reshape(predictions, (-1, predictions.shape[-1]))
targets = targets.flatten()

# Compute negative log-likelihood and accumulate
self._neg_log_likelihood += torch.nn.functional.cross_entropy(
torch.tensor(predictions),
torch.tensor(targets),
reduction="sum",
).item()

# Track number of tokens processed
self._number_tokens += predictions.shape[0]
else:
# If accumulate is False, compute perplexity for each sample individually.
# We assume that sequence length is uniform within a batch, but may vary from batch
# to batch.

# Create batch dimension if it doesn't exist
if targets.ndim == 1:
predictions = numpy.expand_dims(predictions, axis=0)
targets = numpy.expand_dims(targets, axis=0)

# Compute negative log-likelihoods for batch
neg_log_likelihoods = torch.nn.functional.cross_entropy(
torch.tensor(predictions.transpose(0, 2, 1)),
torch.tensor(targets),
reduction="none",
).numpy().mean(-1)

# Compute perplexities for batch
perplexities = numpy.exp(neg_log_likelihoods)

# Store perplexities
if self._perplexities is None:
self._perplexities = perplexities
else:
self._perplexities = numpy.concatenate((self._perplexities, perplexities))

def compute(self) -> Dict[str, Any]:
"""
:return: A dictionary containing the mean perplexity
and the list of perplexities
:return: A dictionary containing the final results.
If accumulate is True, return single perplexity.
Else, return a list of perplexities (one for each sample)
and mean perplexity.
"""
return {
"mean_perplexity": numpy.mean(self.perplexities),
"perplexities": self.perplexities,
}

if self._accumulate:
perplexity = numpy.exp(self._neg_log_likelihood / self._number_tokens)
return {"perplexity": perplexity}
else:
return {
"perplexities": self._perplexities,
"mean_perplexity": numpy.mean(self._perplexities),
}


class PrecisionRecallF1:
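
Taken together, the refactored metric no longer owns a pipeline: callers run inference themselves and pass logits plus target token ids to add_batch, and compute() returns the exponential of the average per-token negative log-likelihood (a single corpus-level perplexity when accumulate=True, per-sample perplexities plus their mean otherwise). The following is a hedged usage sketch based only on the API visible in this diff; inputs are random and shapes are toy values, and torch must be installed since the metric uses it internally.

```python
import numpy

from deepsparse.transformers.metrics import Perplexity  # path of the file edited above

vocab_size, seq_len = 32, 16

# accumulate=True: negative log-likelihood is summed over every token seen,
# and compute() returns one corpus-level perplexity.
corpus_metric = Perplexity(accumulate=True)
for _ in range(3):
    logits = numpy.random.rand(seq_len, vocab_size).astype(numpy.float32)
    targets = numpy.random.randint(0, vocab_size, size=seq_len)
    corpus_metric.add_batch(logits, targets)
print(corpus_metric.compute())  # {"perplexity": ...}

# accumulate=False: one perplexity per sample, plus their mean.
sample_metric = Perplexity(accumulate=False)
logits = numpy.random.rand(2, seq_len, vocab_size).astype(numpy.float32)
targets = numpy.random.randint(0, vocab_size, size=(2, seq_len))
sample_metric.add_batch(logits, targets)
print(sample_metric.compute())  # {"perplexities": ..., "mean_perplexity": ...}
```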