Refactor of perplexity computation (#1197)
* Add input_tokens as optional output

* Refactor Perplexity class to only compute perplexity. All other task-specific processing is handled elsewhere

* Simplify perplexity evaluation. Evaluation takes place at batch size 1 only, so there is no need to consider batched execution. In addition, use input_tokens from the generation pipeline

* Split wikitext at regular intervals of the same length as the sequence length

* Add argument for accumulation of negative log likelihood

* Accumulate likelihood for wikitext

* Simplification

* Add support for wikitext-style ppl evaluation

* Compute each batch instead of storing until the compute method. This drastically reduces memory requirements

* Remove torch dependency

* Move split of dataset into helper function

* Quality fixes

* Remove debugging prints

* Remove debugging prints

* Incorporate fixes for kv-cache

* Include doc string for accumulate

* Add support for the trust-remote-code argument

* Add support for c4

* add a missing include_prompt_logits param

* Remove unnecessary capping at sequence length (it's incorrect for cached models)

* Simplify processing for concatenated datasets

* Fix kv cache update

* Fix kv cache update

* Quality fixes

* remove batch size from pipeline instantiation

* Rename to wikitext2

* Remove trust_remote_code argument

* Remove use_deepsparse_cache argument

* Change padding of output to left in order to match padding of input ids and attention mask

* Allow trust_remote_code to be passed as argument (in some cases tokenizer can be defined by custom code)

* Move process_concatenated_datasets to helpers file

* Added support for max_text_length to speed up processing of long datasets

* Rebase w/ main

* Rebase w/ main

* Fix typo

* Rebase

* Use max_length instead of max_new_tokens

* Rebase

* Added typing and docstring

* Added typing and docstring

* Define concatenated datasets

* Add warning about batch-size not being a supported argument for some datasets

* Add unit test for pipeline and generation in ppl eval

* Add lifecycle in docstring

* Add copyright

* Style fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Quality fixes

* Rebase

* Rebase

* Re-add unit test

* Style fix

* Update unit test

* Update unit test

---------

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
5 people authored Nov 10, 2023
1 parent 91d3735 commit 86490b0
Showing 6 changed files with 462 additions and 167 deletions.
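
For reference (not part of the diff): perplexity is the exponential of the average negative log-likelihood of the target tokens. The commit notes indicate that, with the new accumulate option, the negative log-likelihood is summed across the fixed-length sections that wikitext2 and c4 are split into, and only exponentiated after averaging over all scored tokens:

\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N}\log p\left(x_i \mid x_{<i}\right)\right)
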
141 changes: 119 additions & 22 deletions src/deepsparse/transformers/eval_downstream.py
@@ -62,49 +62,112 @@

import argparse
import json
import logging
from cProfile import Profile
from pstats import Stats

import numpy
from tqdm.auto import tqdm

from datasets import load_dataset, load_metric
from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline
from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1
from deepsparse.transformers.utils.eval_helpers import process_concatenated_datasets


from datasets import load_dataset, load_metric # isort: skip
_LOGGER = logging.getLogger(__name__)


def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"):
if args.max_samples:
batch_size = min(batch_size, args.max_samples)
PPL_DATASETS = ["wikitext2", "c4", "openai_humaneval"]

dataset = load_dataset(dataset_name)["test"]

def perplexity_eval(args, dataset_name="openai_humaneval"):
if dataset_name in ["wikitext2", "c4"]:
if args.kwargs is None:
kwargs = {}
else:
kwargs = json.loads(args.kwargs)
dataset = process_concatenated_datasets(
dataset_name,
args.model_path,
args.max_sequence_length,
kwargs,
)
# Set perplexity computation to accumulate negative log-likelihood across
# sections
accumulate = True
else:
dataset = load_dataset(dataset_name, split="test")
accumulate = False

# We'll use the text generation pipeline to generate a single token.
# Along with the token, it returns the logits for the input sequence
text_generation = Pipeline.create(
task="text-generation",
model_path=args.model_path,
engine_type=args.engine,
num_cores=args.num_cores,
sequence_length=args.max_sequence_length,
max_generated_tokens=1,
trust_remote_code=args.trust_remote_code,
)
perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size)
active_engines = [
engine
for engine in [text_generation.engine, text_generation.multitoken_engine]
if engine
]
print("Engine info: ")
[print(f"{engine}\n") for engine in active_engines]
predictions = []

# Instantiate perplexity metric
perplexity_metrics = Perplexity(accumulate=accumulate)

# Loop through samples
batch_samples = []
run_inference = False
end_evaluation = False
dataset_length = len(dataset)
for idx, sample in _enumerate_progress(dataset, args.max_samples):
predictions.append(sample["prompt"] + sample["canonical_solution"])
if len(predictions) == batch_size:
perplexity_metrics.add_batch(predictions)
predictions = []
if args.max_samples and idx >= args.max_samples:
# Collect input sequence
if dataset_name == "openai_humaneval":
sample = sample["prompt"] + sample["canonical_solution"]
batch_samples.append(sample)

if args.max_samples and idx == args.max_samples - 1:
run_inference = True
end_evaluation = True

if (idx + 1) % args.batch_size == 0 or idx == dataset_length - 1:
run_inference = True

if run_inference:
# Perform single token generation
prediction = text_generation(
sequences=batch_samples,
output_scores=True,
return_input_tokens=True,
fixed_sequences_length=True,
include_prompt_logits=True,
max_length=1,
)

# Handle one sample at a time to make it simpler for masking
for s in range(len(batch_samples)):
# Need to remove tokens that were masked
input_ids = prediction.input_tokens["input_ids"][s].flatten()
logits = prediction.generations[s].score
attention_mask = prediction.input_tokens["attention_mask"][s].flatten()

effective_sequence_length = logits.shape[0]

input_ids = input_ids[-effective_sequence_length:]
attention_mask = attention_mask[-effective_sequence_length:]

logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
input_ids = numpy.compress(attention_mask, input_ids)[1:]

# Add predictions (logits) and targets (input_ids) to metric
perplexity_metrics.add_batch(logits, input_ids)

# Reset batch
batch_samples.clear()
run_inference = False

if end_evaluation:
break

return perplexity_metrics
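
For clarity, here is a minimal standalone sketch of the masking and shifting step performed in the loop above, using toy arrays rather than the pipeline's actual outputs (the shapes, token ids, and the omission of the effective_sequence_length trim are illustrative assumptions):

import numpy

# Toy example of the alignment logic: the logits at position i are scored
# against the token at position i + 1, and padded positions
# (attention_mask == 0) are dropped before the metric sees them.
logits = numpy.random.rand(5, 32)                 # (sequence, vocab) scores, illustrative only
input_ids = numpy.array([0, 0, 101, 2023, 2003])  # left-padded token ids, illustrative only
attention_mask = numpy.array([0, 0, 1, 1, 1])     # zeros mark padding

logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]  # drop padding, then the last position
input_ids = numpy.compress(attention_mask, input_ids)[1:]        # drop padding, then the first token

assert logits.shape[0] == input_ids.shape[0]  # predictions now align with next-token targets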


@@ -473,7 +536,18 @@ def _split_train_val(train_dataset, val_ratio, seed=42):
"imdb": imdb_eval,
"conll2003": conll2003_eval,
"go_emotions": go_emotions_eval,
"openai_humaneval": perplexity_eval,
"openai_humaneval": lambda args: perplexity_eval(
args,
dataset_name="openai_humaneval",
),
"wikitext2": lambda args: perplexity_eval(
args,
dataset_name="wikitext2",
),
"c4": lambda args: perplexity_eval(
args,
dataset_name="c4",
),
}
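
The lambdas bind the dataset name so every registry entry keeps a uniform fn(args) call signature. A rough sketch of how such an entry is looked up and consumed, mirroring how _main finalizes the other metrics (the compute() call is inferred from that pattern and the commit notes, not shown in this hunk):

eval_fn = SUPPORTED_DATASETS["wikitext2"]  # -> lambda args: perplexity_eval(args, dataset_name="wikitext2")
perplexity_metrics = eval_fn(args)         # runs the evaluation loop defined above
results = perplexity_metrics.compute()     # assumed finalization, as done for the other metrics in _main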


@@ -604,7 +678,24 @@ def parse_args():
type=bool,
default=False,
)

parser.add_argument(
"--batch-size",
help="Batch size with which to evaluate model. Default is 1",
type=int,
default=1,
)
parser.add_argument(
"--trust-remote-code",
help="Whether to allow for remote code execution in transformers.",
type=bool,
default=False,
)
parser.add_argument(
"--kwargs",
help="Additional arguments specific to each dataset",
type=str,
default=None,
)
return parser.parse_args()
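
For reference, --kwargs expects a JSON string that perplexity_eval parses and forwards to process_concatenated_datasets on the wikitext2/c4 paths. A small sketch (the max_text_length key is taken from the commit notes on long datasets and is an assumption of this example, not a documented schema):

import json

raw_kwargs = '{"max_text_length": 100000}'  # assumed key, per the commit note on speeding up long datasets
dataset_kwargs = json.loads(raw_kwargs)     # -> {"max_text_length": 100000}, forwarded to process_concatenated_datasets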


@@ -619,6 +710,12 @@ def _main(args):
f"available datasets are {list(SUPPORTED_DATASETS.keys())}"
)

if dataset not in PPL_DATASETS:
_LOGGER.warning(
"Batch-size argument is not supported for this dataset."
"Will use default value of 1."
)

if dataset == "mnli":
mnli_metrics_matched, mnli_metrics_mismatched = mnli_eval(args)
mnli_metrics_matched = mnli_metrics_matched.compute()
