Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of perplexity computation #1197

Merged
merged 98 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 87 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
11ca2b5
Add input_tokes as optional output
anmarques Aug 22, 2023
530d625
Refactor Perplexity class to only compute perplexity. All other task-…
anmarques Aug 22, 2023
c816922
Simplify perplexity evaluation. Evaluation takes place as batch size …
anmarques Aug 22, 2023
5c89d89
Splits wikitext at regular intervals of the same length as the sequen…
anmarques Aug 22, 2023
5767ca0
Add argument for accumulation of negative log likelihood
anmarques Aug 22, 2023
ec2162e
Accumulate likelihood for wikitext
anmarques Aug 22, 2023
a7941ef
Simplification
anmarques Aug 22, 2023
3ddd45c
Add support for wikitext-style ppl evaluation
anmarques Aug 23, 2023
756169c
Compute batch instead of storing until compute method. This drastical…
anmarques Aug 23, 2023
97b5f1a
Remove torch dependency
anmarques Aug 23, 2023
91b5921
Move split of dataset into helper function
anmarques Aug 23, 2023
e7c9342
Merge branch 'main' into research/ppl_refactor
anmarques Aug 23, 2023
8ef20e7
Quality fixes
anmarques Aug 23, 2023
5a60228
Remove debugging prints
anmarques Aug 24, 2023
2559e41
Remove debugging prints
anmarques Aug 24, 2023
3b7e14b
Incorporate fixes for kv-cache
anmarques Aug 24, 2023
b5f845b
Include doc string for accumulate
anmarques Aug 25, 2023
6f3b246
Add support to trust-remote-code arguments
anmarques Aug 25, 2023
42f0da2
Merge branch 'main' into research/ppl_refactor
anmarques Aug 25, 2023
2056ec5
Add support to c4
anmarques Aug 25, 2023
8f15636
Merge branch 'main' into research/ppl_refactor
dbogunowicz Aug 28, 2023
858bee6
add a missing include_prompt_logits param
dbogunowicz Aug 28, 2023
4f6eb6b
Remove unnecessary capping at sequence length (it's incorrect for cac…
anmarques Aug 28, 2023
51370d4
Merge branch 'main' into research/ppl_refactor
anmarques Aug 28, 2023
b3d99d1
Merge branch 'main' into research/ppl_refactor
anmarques Aug 28, 2023
53d8a09
Merge branch 'main' into research/ppl_refactor
anmarques Aug 29, 2023
ab757d0
Simplify processing for concatenated datasets
anmarques Aug 29, 2023
bc0920a
Merge branch 'main' into research/ppl_refactor
anmarques Sep 1, 2023
f21eaf3
Fix kv cache update
anmarques Sep 1, 2023
2a18c45
Fix kv cache update
anmarques Sep 1, 2023
7e8da1c
Quality fixes
anmarques Sep 1, 2023
717f518
Merge branch 'main' into research/ppl_refactor
anmarques Sep 1, 2023
c8187f6
Merge branch 'main' into research/ppl_refactor
dbogunowicz Sep 6, 2023
1f9c358
remove batch size from pipeline instantiation
anmarques Sep 8, 2023
b2d2827
Merge branch 'main' into research/ppl_refactor
anmarques Sep 8, 2023
099b366
Rename to wikitext2
anmarques Sep 8, 2023
5455c7c
Remove trust_remote_code argument
anmarques Sep 8, 2023
6a330d4
Remove use_deepsparse_cache argument
anmarques Sep 8, 2023
e19eb83
Merge branch 'main' into research/ppl_refactor
bfineran Sep 11, 2023
a448667
Change padding of output to left in order to match padding of input i…
anmarques Sep 11, 2023
7b5d8e5
Merge remote-tracking branch 'origin/research/ppl_refactor' into rese…
anmarques Sep 11, 2023
54b560c
Allow trust_remote_code to be passed as argument (in some cases token…
anmarques Sep 11, 2023
ad35340
Move process_concatenated_datasets to helpers file
anmarques Sep 11, 2023
b16a5f6
Added support for max_text_length to speed up processing of long data…
anmarques Sep 13, 2023
065864a
Rebase w/ main
anmarques Sep 20, 2023
7583e28
Merge branch 'main' into research/ppl_refactor
anmarques Sep 20, 2023
59b93c5
Rebase w/ main
anmarques Sep 20, 2023
f4554b1
Fix typo
anmarques Sep 20, 2023
15c031a
Merge branch 'main' into research/ppl_refactor
anmarques Sep 22, 2023
1ba429c
Merge branch 'main' into research/ppl_refactor
anmarques Sep 25, 2023
5530895
Merge branch 'main' into research/ppl_refactor
anmarques Sep 26, 2023
c5bd383
Rebase
anmarques Sep 26, 2023
0672e0d
Merge branch 'main' into research/ppl_refactor
anmarques Sep 26, 2023
091aeca
Use max_length instead of max_new_tokens
anmarques Sep 26, 2023
d4d7a36
Merge branch 'main' into research/ppl_refactor
anmarques Sep 27, 2023
75417f2
Merge branch 'main' into research/ppl_refactor
anmarques Sep 28, 2023
8751886
Merge branch 'main' into research/ppl_refactor
anmarques Sep 29, 2023
cebad84
Merge branch 'main' into research/ppl_refactor
anmarques Oct 2, 2023
141c966
Merge branch 'main' into research/ppl_refactor
anmarques Oct 3, 2023
6bc08bc
Rebase
anmarques Oct 3, 2023
dc943d7
Added typing and docstring
anmarques Oct 3, 2023
8f3743a
Added typing and docstring
anmarques Oct 3, 2023
5e1d808
Define concantenated datasets
anmarques Oct 3, 2023
0785321
Add warning about batch-size not being a supported argument for some …
anmarques Oct 3, 2023
d8914f0
Add unit test for pipeline and generation in ppl eval
anmarques Oct 3, 2023
5bf076b
Add lifecycle in docstring
anmarques Oct 3, 2023
2e56b50
Merge branch 'main' into research/ppl_refactor
anmarques Oct 4, 2023
96c5794
Merge branch 'main' into research/ppl_refactor
anmarques Oct 10, 2023
e2ecca4
Merge branch 'main' into research/ppl_refactor
anmarques Oct 11, 2023
3e536c5
Merge branch 'main' into research/ppl_refactor
anmarques Oct 13, 2023
cb08231
Merge branch 'main' into research/ppl_refactor
anmarques Oct 20, 2023
ecf3b77
Add copyright
anmarques Oct 20, 2023
fe37c32
Style fixes
anmarques Oct 20, 2023
ddd0325
Quality fixes
anmarques Oct 20, 2023
24a91a3
Quality fixes
anmarques Oct 20, 2023
301115c
Quality fixes
anmarques Oct 20, 2023
e402da9
Quality fixes
anmarques Oct 20, 2023
b48e05f
Quality fixes
anmarques Oct 20, 2023
61b9c5c
Quality fixes
anmarques Oct 20, 2023
8c08e84
Merge branch 'main' into research/ppl_refactor
anmarques Oct 20, 2023
34ee8f6
Quality fixes
anmarques Oct 20, 2023
b032101
Quality fixes
anmarques Oct 20, 2023
f1de171
Merge branch 'main' into research/ppl_refactor
anmarques Oct 20, 2023
f3cbf3d
Quality fixes
anmarques Oct 20, 2023
483449e
Quality fixes
anmarques Oct 20, 2023
a329e25
Merge branch 'main' into research/ppl_refactor
anmarques Oct 23, 2023
fe4b267
Merge branch 'main' into research/ppl_refactor
anmarques Oct 23, 2023
c55a05e
Merge branch 'main' into research/ppl_refactor
anmarques Oct 24, 2023
5d46ddf
Merge branch 'main' into research/ppl_refactor
anmarques Oct 26, 2023
bf139c2
Merge branch 'main' into research/ppl_refactor
anmarques Oct 30, 2023
e6e7828
Rebase
anmarques Oct 30, 2023
d7c6e5a
Rebase
anmarques Oct 30, 2023
f8c64a1
Merge branch 'main' into research/ppl_refactor
rahul-tuli Nov 1, 2023
28919b1
Merge branch 'main' into research/ppl_refactor
anmarques Nov 8, 2023
21c6f0d
Re-add unit test
anmarques Nov 8, 2023
fa0cb4b
Style fix
anmarques Nov 8, 2023
bf1b0cf
Update unit test
anmarques Nov 8, 2023
0c618a6
Update unit test
anmarques Nov 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 119 additions & 22 deletions src/deepsparse/transformers/eval_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,49 +62,112 @@

import argparse
import json
import logging
from cProfile import Profile
from pstats import Stats

import numpy
from tqdm.auto import tqdm

from datasets import load_dataset, load_metric
from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline
from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1
from deepsparse.transformers.utils.eval_helpers import process_concatenated_datasets


from datasets import load_dataset, load_metric # isort: skip
_LOGGER = logging.getLogger(__name__)


def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"):
if args.max_samples:
batch_size = min(batch_size, args.max_samples)
PPL_DATASETS = ["wikitext2", "c4", "openai_humaneval"]

dataset = load_dataset(dataset_name)["test"]

def perplexity_eval(args, dataset_name="openai_humaneval"):
if dataset_name in ["wikitext2", "c4"]:
if args.kwargs is None:
kwargs = {}
else:
kwargs = json.loads(args.kwargs)
dataset = process_concatenated_datasets(
dataset_name,
args.model_path,
args.max_sequence_length,
kwargs,
)
# Set perplexity computation to accumulate negative log-likelihood across
# sections
accumulate = True
else:
dataset = load_dataset(dataset_name, split="test")
accumulate = False

# We'll use the text generation pipeline to generate a single token.
# Along with the token, it returns the logits for input sequence
text_generation = Pipeline.create(
task="text-generation",
model_path=args.model_path,
engine_type=args.engine,
num_cores=args.num_cores,
sequence_length=args.max_sequence_length,
max_generated_tokens=1,
trust_remote_code=args.trust_remote_code,
anmarques marked this conversation as resolved.
Show resolved Hide resolved
)
perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size)
active_engines = [
engine
for engine in [text_generation.engine, text_generation.multitoken_engine]
if engine
]
print("Engine info: ")
[print(f"{engine}\n") for engine in active_engines]
predictions = []

# Instantiate perplexity metric
perplexity_metrics = Perplexity(accumulate=accumulate)

# Loop through samples
batch_samples = []
run_inference = False
end_evaluation = False
dataset_length = len(dataset)
for idx, sample in _enumerate_progress(dataset, args.max_samples):
predictions.append(sample["prompt"] + sample["canonical_solution"])
if len(predictions) == batch_size:
perplexity_metrics.add_batch(predictions)
predictions = []
if args.max_samples and idx >= args.max_samples:
# Collect input sequence
if dataset_name == "openai_humaneval":
sample = sample["prompt"] + sample["canonical_solution"]
batch_samples.append(sample)

if args.max_samples and idx == args.max_samples - 1:
run_inference = True
end_evaluation = True

if (idx + 1) % args.batch_size == 0 or idx == dataset_length - 1:
run_inference = True

if run_inference:
# Perform single token generation
prediction = text_generation(
sequences=batch_samples,
anmarques marked this conversation as resolved.
Show resolved Hide resolved
output_scores=True,
return_input_tokens=True,
fixed_sequences_length=True,
include_prompt_logits=True,
max_length=1,
)

# Handle one sample at a time to make it simpler for masking
for s in range(len(batch_samples)):
# Need to remove tokens that were masked
input_ids = prediction.input_tokens["input_ids"][s].flatten()
logits = prediction.generations[s].score
attention_mask = prediction.input_tokens["attention_mask"][s].flatten()

effective_sequence_length = logits.shape[0]

input_ids = input_ids[-effective_sequence_length:]
attention_mask = attention_mask[-effective_sequence_length:]

logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
anmarques marked this conversation as resolved.
Show resolved Hide resolved
input_ids = numpy.compress(attention_mask, input_ids)[1:]

# Add predictions (logits) and targets (input_ids) to metric
perplexity_metrics.add_batch(logits, input_ids)

# Reset batch
batch_samples.clear()
run_inference = False

if end_evaluation:
break

return perplexity_metrics


Expand Down Expand Up @@ -473,7 +536,18 @@ def _split_train_val(train_dataset, val_ratio, seed=42):
"imdb": imdb_eval,
"conll2003": conll2003_eval,
"go_emotions": go_emotions_eval,
"openai_humaneval": perplexity_eval,
"openai_humaneval": lambda args: perplexity_eval(
args,
dataset_name="openai_humaneval",
),
"wikitext2": lambda args: perplexity_eval(
args,
dataset_name="wikitext2",
),
"c4": lambda args: perplexity_eval(
args,
dataset_name="c4",
),
}


Expand Down Expand Up @@ -604,7 +678,24 @@ def parse_args():
type=bool,
default=False,
)

parser.add_argument(
anmarques marked this conversation as resolved.
Show resolved Hide resolved
"--batch-size",
help="Batch size with which to evaluate model. Default is 1",
type=int,
default=1,
)
parser.add_argument(
"--trust-remote-code",
help="Whether to allow for remote code execution in transformers.",
type=bool,
default=False,
)
parser.add_argument(
"--kwargs",
help="Additional arguments specific to each dataset",
type=str,
default=None,
)
return parser.parse_args()


Expand All @@ -619,6 +710,12 @@ def _main(args):
f"available datasets are {list(SUPPORTED_DATASETS.keys())}"
)

if dataset not in PPL_DATASETS:
_LOGGER.warning(
"Batch-size argument is not supported for this dataset."
"Will use default value of 1."
)

if dataset == "mnli":
mnli_metrics_matched, mnli_metrics_mismatched = mnli_eval(args)
mnli_metrics_matched = mnli_metrics_matched.compute()
Expand Down
Loading
Loading