From 88447399b4ed28c2247885e9ce9925668b58704f Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Fri, 10 Nov 2023 10:45:40 -0500 Subject: [PATCH] Refactor of perplexity computation (#1197) * Add input_tokes as optional output * Refactor Perplexity class to only compute perplexity. All other task-specific processing is handled elsewhere * Simplify perplexity evaluation. Evaluation takes place as batch size 1 only, so no need to consider batched execution. In addition, use input_tokens from generation pipeline * Splits wikitext at regular intervals of the same length as the sequence length * Add argument for accumulation of negative log likelihood * Accumulate likelihood for wikitext * Simplification * Add support for wikitext-style ppl evaluation * Compute batch instead of storing until compute method. This drastically reduced memory requirements * Remove torch dependency * Move split of dataset into helper function * Quality fixes * Remove debugging prints * Remove debugging prints * Incorporate fixes for kv-cache * Include doc string for accumulate * Add support to trust-remote-code arguments * Add support to c4 * add a missing include_prompt_logits param * Remove unnecessary capping at sequence length (it's incorrect for cached models) * Simplify processing for concatenated datasets * Fix kv cache update * Fix kv cache update * Quality fixes * remove batch size from pipeline instantiation * Rename to wikitext2 * Remove trust_remote_code argument * Remove use_deepsparse_cache argument * Change padding of output to left in order to match padding of input ids and attention mask * Allow trust_remote_code to be passed as argument (in some cases tokenizer can be defined by custom code) * Move process_concatenated_datasets to helpers file * Added support for max_text_length to speed up processing of long datasets * Rebase w/ main * Rebase w/ main * Fix typo * Rebase * Use max_length instead of max_new_tokens * Rebase * Added typing and docstring * Added typing and docstring * Define concantenated datasets * Add warning about batch-size not being a supported argument for some datasets * Add unit test for pipeline and generation in ppl eval * Add lifecycle in docstring * Add copyright * Style fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Quality fixes * Rebase * Rebase * Re-add unit test * Style fix * Update unit test * Update unit test --------- Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Co-authored-by: Damian Co-authored-by: Benjamin Fineran Co-authored-by: Rahul Tuli --- .../transformers/eval_downstream.py | 141 +++- src/deepsparse/transformers/metrics.py | 261 +++--- .../transformers/pipelines/text_generation.py | 24 +- .../transformers/utils/eval_helpers.py | 183 +++++ src/deepsparse/transformers/utils/helpers.py | 4 +- .../pipelines/test_text_generation.py | 775 +++--------------- 6 files changed, 556 insertions(+), 832 deletions(-) create mode 100644 src/deepsparse/transformers/utils/eval_helpers.py diff --git a/src/deepsparse/transformers/eval_downstream.py b/src/deepsparse/transformers/eval_downstream.py index 8f928c33be..f9835aa58e 100644 --- a/src/deepsparse/transformers/eval_downstream.py +++ b/src/deepsparse/transformers/eval_downstream.py @@ -62,49 +62,112 @@ import argparse import json +import logging from cProfile import Profile from pstats import Stats import numpy from tqdm.auto import tqdm +from datasets import load_dataset, load_metric from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1 +from deepsparse.transformers.utils.eval_helpers import process_concatenated_datasets -from datasets import load_dataset, load_metric # isort: skip +_LOGGER = logging.getLogger(__name__) -def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"): - if args.max_samples: - batch_size = min(batch_size, args.max_samples) +PPL_DATASETS = ["wikitext2", "c4", "openai_humaneval"] - dataset = load_dataset(dataset_name)["test"] +def perplexity_eval(args, dataset_name="openai_humaneval"): + if dataset_name in ["wikitext2", "c4"]: + if args.kwargs is None: + kwargs = {} + else: + kwargs = json.loads(args.kwargs) + dataset = process_concatenated_datasets( + dataset_name, + args.model_path, + args.max_sequence_length, + kwargs, + ) + # Set perplexity computation to accumulate negative log-likelihood across + # sections + accumulate = True + else: + dataset = load_dataset(dataset_name, split="test") + accumulate = False + + # We'll use the text generation pipeline to generate a single token. + # Along with the token, it returns the logits for input sequence text_generation = Pipeline.create( task="text-generation", model_path=args.model_path, engine_type=args.engine, num_cores=args.num_cores, sequence_length=args.max_sequence_length, - max_generated_tokens=1, + trust_remote_code=args.trust_remote_code, ) - perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size) - active_engines = [ - engine - for engine in [text_generation.engine, text_generation.multitoken_engine] - if engine - ] - print("Engine info: ") - [print(f"{engine}\n") for engine in active_engines] - predictions = [] + + # Instantiate perplexity metric + perplexity_metrics = Perplexity(accumulate=accumulate) + + # Loop through samples + batch_samples = [] + run_inference = False + end_evaluation = False + dataset_length = len(dataset) for idx, sample in _enumerate_progress(dataset, args.max_samples): - predictions.append(sample["prompt"] + sample["canonical_solution"]) - if len(predictions) == batch_size: - perplexity_metrics.add_batch(predictions) - predictions = [] - if args.max_samples and idx >= args.max_samples: + # Collect input sequence + if dataset_name == "openai_humaneval": + sample = sample["prompt"] + sample["canonical_solution"] + batch_samples.append(sample) + + if args.max_samples and idx == args.max_samples - 1: + run_inference = True + end_evaluation = True + + if (idx + 1) % args.batch_size == 0 or idx == dataset_length - 1: + run_inference = True + + if run_inference: + # Perform single token generation + prediction = text_generation( + sequences=batch_samples, + output_scores=True, + return_input_tokens=True, + fixed_sequences_length=True, + include_prompt_logits=True, + max_length=1, + ) + + # Handle one sample at a time to make it simpler for masking + for s in range(len(batch_samples)): + # Need to remove tokens that were masked + input_ids = prediction.input_tokens["input_ids"][s].flatten() + logits = prediction.generations[s].score + attention_mask = prediction.input_tokens["attention_mask"][s].flatten() + + effective_sequence_length = logits.shape[0] + + input_ids = input_ids[-effective_sequence_length:] + attention_mask = attention_mask[-effective_sequence_length:] + + logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :] + input_ids = numpy.compress(attention_mask, input_ids)[1:] + + # Add predictions (logits) and targets (input_ids) to metric + perplexity_metrics.add_batch(logits, input_ids) + + # Reset batch + batch_samples.clear() + run_inference = False + + if end_evaluation: break + return perplexity_metrics @@ -473,7 +536,18 @@ def _split_train_val(train_dataset, val_ratio, seed=42): "imdb": imdb_eval, "conll2003": conll2003_eval, "go_emotions": go_emotions_eval, - "openai_humaneval": perplexity_eval, + "openai_humaneval": lambda args: perplexity_eval( + args, + dataset_name="openai_humaneval", + ), + "wikitext2": lambda args: perplexity_eval( + args, + dataset_name="wikitext2", + ), + "c4": lambda args: perplexity_eval( + args, + dataset_name="c4", + ), } @@ -604,7 +678,24 @@ def parse_args(): type=bool, default=False, ) - + parser.add_argument( + "--batch-size", + help="Batch size with which to evaluate model. Default is 1", + type=int, + default=1, + ) + parser.add_argument( + "--trust-remote-code", + help="Whether to allow for remote code execution in transformers.", + type=bool, + default=False, + ) + parser.add_argument( + "--kwargs", + help="Additional arguments specific to each dataset", + type=str, + default=None, + ) return parser.parse_args() @@ -619,6 +710,12 @@ def _main(args): f"available datasets are {list(SUPPORTED_DATASETS.keys())}" ) + if dataset not in PPL_DATASETS: + _LOGGER.warning( + "Batch-size argument is not supported for this dataset." + "Will use default value of 1." + ) + if dataset == "mnli": mnli_metrics_matched, mnli_metrics_mismatched = mnli_eval(args) mnli_metrics_matched = mnli_metrics_matched.compute() diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py index f2e717a08f..1952ec2155 100644 --- a/src/deepsparse/transformers/metrics.py +++ b/src/deepsparse/transformers/metrics.py @@ -16,16 +16,11 @@ Utilities for evaluation metric computation """ - -from itertools import compress -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import numpy -from tqdm import tqdm -from deepsparse import Pipeline -from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline -from deepsparse.transformers.utils.helpers import pad_to_fixed_length +from scipy.special import log_softmax from sklearn.metrics import precision_recall_fscore_support @@ -36,139 +31,107 @@ class Perplexity: - def __init__(self, pipeline: Pipeline, batch_size: int = 16): + def __init__(self, accumulate: bool = False): """ - Given the pipeline, compute the perplexity of the model - on the given text input. + Class for computing perplexity. - Code adapted from: - https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501 + Each batch is processed via the "add_batches" method. + At the end the data is reduced to a single perplexity + metric via the "compute" method. - :param pipeline: The pipeline to use for text generation - :param batch_size: The batch size to split the input text into - non-overlapping batches - """ - torch = _import_torch() - if not isinstance(pipeline, TextGenerationPipeline): - raise ValueError( - "Perplexity can only be computed for text generation pipelines" - ) - self._pipeline = pipeline - self._batch_size = batch_size - self._sequence_length = pipeline.sequence_length - self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - - self.perplexities = [] - - def add_batch(self, predictions: List[str]): + Example: + metric = Perplexity() + for prediction, target in samples: + metric.add_batch(prediction, target) + perplexity_value = metric.compute() + + :param accumulate: If True, accumulate negative log-likelihood + over samples. If False, perplexity is computed separately + for each sampled and then averaged in the end. """ - Run the model on the given input sequences and compute the perplexity. - The resulting perplexity is appended to the list of perplexities. + self._predictions = None + self._targets = None + self._accumulate = accumulate + if accumulate: + self._neg_log_likelihood = 0.0 + self._number_tokens = 0 + else: + self._perplexities = None - :param predictions: The predictions to compute perplexity on + def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray): + """ + Computes perplexity or negative log-likelihood for each batch + (depending on accumulate argument) + and track results. + + Tracks perplexity or negative log-likelihood since storing + predictions may require a lot of memory. + + :param predictions: predicted scores. + Accepted shapes: + - [batch_size, sequence_length, vocab_size] + - [sequence_length, vocab_size] (batch size = 1) + Note: sequence length has to be uniform within a batch, but not all + batches require the same sequence length + :param targets: target values - index of correct vocabulary entry """ - torch = _import_torch() - # tokenize the input text - encodings = self._pipeline.tokenizer( - predictions, - return_attention_mask=True, - max_length=self._sequence_length, - truncation=True, - padding="max_length", - ) - encoded_texts = encodings["input_ids"] - attention_masks = encodings["attention_mask"] - - for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)): - end_index = min(start_index + self._batch_size, len(encoded_texts)) - encoded_batch = encoded_texts[start_index:end_index] - attention_mask = attention_masks[start_index:end_index] - - # Computing the ground truth labels - - # `encoded_batch` contains sequences of tokens padded - # with tokens from the left side. We need to remove - # them and zero-pad from the right side up to the length - # of the longest sequence in the batch - - encoded_batch = [ - list(compress(sequence, attn_mask)) - for (sequence, attn_mask) in zip(encoded_batch, attention_mask) - ] - max_sequence_len = max(len(sequence) for sequence in encoded_batch) - - encoded_batch = [ - pad_to_fixed_length(numpy.array(sequence), max_sequence_len) - for sequence in encoded_batch - ] - encoded_batch = numpy.stack(encoded_batch) - - # We need to apply the analogous transformation to the attention mask - attention_mask = numpy.array(attention_mask) - attention_mask = [ - list(filter(lambda num: num != 0, mask)) for mask in attention_mask - ] - attention_mask = [ - pad_to_fixed_length(numpy.array(mask), max_sequence_len) - for mask in attention_mask - ] - attention_mask = numpy.stack(attention_mask) - - labels = encoded_batch - - out = self._pipeline( - sequences=predictions, - return_logits=True, - fixed_sequences_length=True, - include_prompt_logits=True, - ) - - logits = out.logits - - if not self._pipeline.cache_support_enabled: - # when running inference without cache, we need to apply - # analogous transformations to the logits as we did to the labels - # and attention mask - - # remove "nonsensical" logits for tokens - logits = [ - logit[-attn_mask.sum() :, :] - for (logit, attn_mask) in zip(logits, attention_mask) - ] - # pad logits to max length - logits = [ - pad_to_fixed_length(logit, max_sequence_len) for logit in logits - ] - logits = numpy.stack(logits) - - # shift logits and labels create the input and target for the loss function - shift_logits = logits[:, :-1, :] - shift_labels = labels[:, 1:] - shift_attention_mask_batch = attention_mask[:, 1:] - - # compute perplexity for this batch - perplexity_batch = torch.exp( - ( - self._loss_fct( - torch.tensor(shift_logits.transpose(0, 2, 1)), - torch.tensor(shift_labels), - ) - * torch.tensor(shift_attention_mask_batch) - ).sum(1) - / torch.tensor(shift_attention_mask_batch).sum(1) - ) - self.perplexities.extend(perplexity_batch.numpy().tolist()) + if self._accumulate: + # If accumulate is True, every token from the batch contributes + # equally to the negative log-likelihood. + # Thus, merge batch and sequence length dimensions and compute negative + # log-likelihood for all tokens, and accumulate to total + predictions = numpy.reshape(predictions, (-1, predictions.shape[-1])) + targets = targets.flatten() + + # Compute negative log-likelihood and accumulate + self._neg_log_likelihood += _cross_entropy( + predictions, targets, reduction="sum" + ).sum() + + # Track number of tokens processed + self._number_tokens += predictions.shape[0] + else: + # If accumulate is False, compute perplexity for + # each sample individually. + # We assume that sequence length is uniform within a batch, + # but may vary from batch to batch. + + # Create batch dimension if it doesn't exist + if targets.ndim == 1: + predictions = numpy.expand_dims(predictions, axis=0) + targets = numpy.expand_dims(targets, axis=0) + + # Compute negative log-likelihoods for batch + neg_log_likelihoods = _cross_entropy(predictions, targets) + + # Compute perplexities for batch + perplexities = numpy.exp(neg_log_likelihoods) + + # Store perplexities + if self._perplexities is None: + self._perplexities = perplexities + else: + self._perplexities = numpy.concatenate( + (self._perplexities, perplexities) + ) def compute(self) -> Dict[str, Any]: """ - :return: A dictionary containing the mean perplexity - and the list of perplexities + :return: A dictionary containing the final results. + If accumulate is True, return single perplexity. + Else, return a list of perplexities (one for each sample) + and mean perplexity. """ - return { - "mean_perplexity": numpy.mean(self.perplexities), - "perplexities": self.perplexities, - } + + if self._accumulate: + perplexity = numpy.exp(self._neg_log_likelihood / self._number_tokens) + return {"perplexity": perplexity} + else: + return { + "perplexities": self._perplexities, + "mean_perplexity": numpy.mean(self._perplexities), + } class PrecisionRecallF1: @@ -231,19 +194,33 @@ def compute(self) -> Dict[str, float]: return results -def _import_torch(): +def _cross_entropy( + predictions: numpy.ndarray, + targets: numpy.ndarray, + reduction: str = "mean", +) -> float: """ - Import and return the required torch module. Raises an ImportError if torch is not - installed. + Calculate the cross-entropy loss between predicted probabilities and target labels. + + Args: + predictions (numpy.ndarray): Predicted logits. + targets (nnumpy.ndarray): Target class labels. + reduction (str, optional): Specifies the reduction method for the loss. + - "mean" (default): Computes the mean loss over all samples. + - "sum": Computes the sum of losses over all samples. - :raises ImportError: if torch is not installed - :return: torch module + Returns: + float: The computed cross-entropy loss. """ - try: - import torch - - return torch - except ImportError as import_error: - raise ImportError( - "Please install `deepsparse[torch]` to use this pathway" - ) from import_error + + logp = log_softmax(predictions, axis=-1) + neg_log_likelihoods = -1.0 * numpy.take_along_axis( + logp, numpy.expand_dims(targets, axis=-1), axis=-1 + ) + neg_log_likelihoods = numpy.squeeze(neg_log_likelihoods, axis=-1) + if reduction == "mean": + neg_log_likelihoods = neg_log_likelihoods.mean(axis=-1) + elif reduction == "sum": + neg_log_likelihoods = neg_log_likelihoods.sum(axis=-1) + + return neg_log_likelihoods diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 694d11d664..cd6bd6c1f1 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -91,6 +91,10 @@ class Config: alias="prompt", description="The input sequences to generate the text from.", ) + return_input_tokens: bool = Field( + default=False, + description="A flag that indicates whether to return " "the input_tokens. ", + ) include_prompt_logits: bool = Field( default=False, description="A flag that indicates whether to return " @@ -182,6 +186,15 @@ class TextGenerationOutput(BaseModel): "prompt provided. If streamng is enabled, the next generated token is returned." "Otherwise, the full generated sequence is returned." ) + input_tokens: Optional[ + Any + ] = Field( # dictionary mapping "token_ids" and "attention_mask" to numpy arrays + default=None, + description="The output of the tokenizer." + "Dictionary containing token_ids and attention_mask, " + "both mapping to arrays of size " + "[batch_size, sequence_length]", + ) class Config: arbitrary_types_allowed = True @@ -523,6 +536,8 @@ def process_inputs( context = dict( prompts=original_inputs, streaming=inputs.streaming, + return_input_tokens=inputs.return_input_tokens, + input_tokens=input_tokens, generation_config=generation_config, include_prompt_logits=inputs.include_prompt_logits, callback=inputs.callback, @@ -644,8 +659,15 @@ def process_engine_outputs( ] generations = grouped_generations + input_tokens = ( + kwargs.get("input_tokens") if kwargs.get("return_input_tokens") else None + ) + outputs = dict( - created=datetime.datetime.now(), prompts=prompts, generations=generations + created=datetime.datetime.now(), + prompts=prompts, + generations=generations, + input_tokens=input_tokens, ) if self._debug: diff --git a/src/deepsparse/transformers/utils/eval_helpers.py b/src/deepsparse/transformers/utils/eval_helpers.py new file mode 100644 index 0000000000..4c0e68b9de --- /dev/null +++ b/src/deepsparse/transformers/utils/eval_helpers.py @@ -0,0 +1,183 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Mapping, Union + +import numpy +from transformers import AutoTokenizer, PreTrainedTokenizerFast + +from datasets import load_dataset + + +CONCATENATED_DATSETS = ["wikitext2", "c4"] + + +def process_concatenated_datasets( + dataset_name: str, + model_path: str, + max_sequence_length: int, + kwargs: Mapping, +) -> list: + """ + Concatenate text datasets and split them into chunks text that, after + tokenization, have size "max_sequence_length" tokens. + + Args: + dataset_name (str): The name of the dataset to process. + Options: "wikitext2" or "c4". + model_path (str): The path to a pretrained transformer model for tokenization. + max_sequence_length (int): The maximum number of tokens in each sequence. + kwargs (mapping): Additional keyword arguments. + - eos (str, optional): The end-of-sentence token. + Default is "\n\n" for wikitext2 and "" for c4. + - bos (str, optional): The beginning-of-sentence token. + Default is "". + - raw_samples (int, optional): The number of raw samples to use. + Default is None. + - data_file (int, optional): The index of the data file to use for dataset. + Not used in wikitext2. Default is 0 for c4. + - max_text_length (int, optional): The maximum length of text to consider. + Returns: + list: A list of text sequences. + + Raises: + ValueError: If an invalid dataset_name is provided. + """ + + if dataset_name not in CONCATENATED_DATSETS: + raise KeyError( + f"dataset {dataset_name} not supported for concatenated processing, " + f"available datasets are {list(CONCATENATED_DATSETS.keys())}" + ) + + if dataset_name == "wikitext2": + eos = kwargs.get("eos", "\n\n") + bos = kwargs.get("bos", "") + + raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + raw_text = raw_dataset["text"] + elif dataset_name == "c4": + eos = kwargs.get("eos", "<|endoftext|>") + bos = kwargs.get("bos", "") + raw_samples = kwargs.get("raw_samples", None) + data_file = kwargs.get("data_file", 0) + if data_file is not None: + raw_dataset = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={ + "validation": f"en/c4-validation.{data_file:05d}-of-00008.json.gz" + }, + split="validation", + ) + else: + raw_dataset = load_dataset( + "allenai/c4", + "allenai--c4", + split="validation", + ) + if raw_samples is not None: + raw_dataset = raw_dataset[:raw_samples] + raw_text = raw_dataset["text"] + + # Dataset is split into sections that contain "max_sequence_length" tokens. + # To split the dataset, first tokenize text + tokenizer = AutoTokenizer.from_pretrained(model_path) + return _split_text_by_tokens( + raw_text, + eos, + bos, + tokenizer, + max_sequence_length, + kwargs.get("max_text_length", None), + ) + + +def _split_text_by_tokens( + text: List[str], + eos: str, + bos: str, + tokenizer: PreTrainedTokenizerFast, + sequence_length: int, + max_text_length: Union[None, int], +) -> List[str]: + """ + Tokenizes and splits a list of concatenated text samples into + sections of specified maximum token length. + + Args: + text (List[str]): List of concatenated text samples to be tokenized and split. + eos (str): The end-of-sentence token. + bos (str): The beginning-of-sentence token. + tokenizer (PreTrainedTokenizerFast): Tokenizer for tokenizing the text. + sequence_length (int): The maximum number of tokens in each section. + max_text_length (Union[None, int]): The maximum length of text to consider. + - If None, the entire text is tokenized and split. + - If -1, each sample is tokenized separately. + - If a positive integer, the text is split into sections of this + length before tokenization. + + Returns: + List[str]: A list of sections where each section contains a + maximum of "sequence_length" tokens. + """ + + text = [bos + sample + eos for sample in text] + + if max_text_length is None: + text = "".join(text) + input_tokens = tokenizer(text, return_tensors="np")["input_ids"][0] + elif max_text_length == -1: # per sample tokenization + input_tokens = [] + for slice in text: + input_tokens.append(tokenizer(slice, return_tensors="np")["input_ids"][0]) + input_tokens = numpy.concatenate(input_tokens) + else: + text = "".join(text) + text_slices = len(text) // max_text_length + sliced_text = [ + text[i * max_text_length : (i + 1) * max_text_length] + for i in range(text_slices) + ] + sliced_text.append(text[text_slices * max_text_length :]) + input_tokens = [] + for slice in sliced_text: + input_tokens.append(tokenizer(slice, return_tensors="np")["input_ids"][0]) + input_tokens = numpy.concatenate(input_tokens) + + # Then split the tokenized text into sections of size "max_sequence_length" and + # decode each section back into text format + split_text = [] + for i in range(len(input_tokens) // sequence_length): + start = i * sequence_length + end = (i + 1) * sequence_length + split_text.append( + tokenizer.decode( + input_tokens[start:end], + clean_up_tokenization_spaces=False, + ) + ) + + # Handle any leftover tokens + if (i + 1) * sequence_length < len(input_tokens): + start = (i + 1) * sequence_length + end = len(input_tokens) + split_text.append( + tokenizer.decode( + input_tokens[start:end], + clean_up_tokenization_spaces=False, + ) + ) + + return split_text diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 38e3ec4a4c..23b8244b71 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -322,7 +322,7 @@ def pad_to_fixed_length( ) -> numpy.ndarray: """ Pads the array to a fixed length along the given axis. - The padding is done on the right side of the array. + The padding is done on the left side of the array. :param array: array to pad :param max_len: maximum length to pad to @@ -334,7 +334,7 @@ def pad_to_fixed_length( padding = [(0, 0)] * len(array.shape) # for the specified axis, pad to the max length # (from the right side of the array) - padding[axis] = (0, max_len - array.shape[axis]) + padding[axis] = (max_len - array.shape[axis], 0) return numpy.pad(array, padding, mode="constant", constant_values=value) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 5298c2f1dd..fb25a33883 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -13,675 +13,120 @@ # limitations under the License. import inspect -from typing import List, Optional, Tuple import numpy -from transformers import GenerationConfig import pytest from deepsparse import Pipeline from deepsparse.transformers.utils.helpers import prepends_bos_token -from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource - - -_PRECISION = 1e-3 - -NATURAL_LANGUAGE_PROMPT = """ -Didn't know what time it was, the lights were low -I leaned back on my radio -Some cat was layin' down some rock 'n' roll -"Lotta soul," he said -Then the loud sound did seem to fade -Came back like a slow voice on a wave of phase -That weren't no DJ, that was hazy cosmic jive -""" - -CODE_LANGUAGE_PROMPT = """ -def Fibonacci(n): - # Check if input is 0 then it will - # print incorrect input - if n < 0: - print("Incorrect input") - # Check if n is 0 - # then it will return 0 - elif n == 0: - return 0 -""" - - -@pytest.mark.parametrize( - "internal_kv_cache", - [ - True, - False, - ], -) -@pytest.mark.parametrize( - "pipeline_type", - ["text_generation", "chat"], -) -@pytest.mark.parametrize( - "model_stub, " - "model_name, " - "uses_bos_token, " - "prompt, " - "logits_max_diff_kv_cache_has_been_filled", - [ - ( - "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - "huggingface/bigpython_bigquery_thepile/base-none", - "salesforce/codegen-350m-mono", - False, - CODE_LANGUAGE_PROMPT, - 13, - ), - ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" - "opt_pretrain/base-none", - "facebook/opt-1.3b", - True, - NATURAL_LANGUAGE_PROMPT, - 3.9, - ), - ], - scope="class", -) -@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") -class TestTextGenerationPipeline: - """ - This test suite is meant to test the main scenarios of - the text generation pipeline. - """ - - def get_pipeline(self, **kwargs): - if not kwargs: - # return the default pipeline - if self.default_pipeline: - return self.default_pipeline - else: - self.default_pipeline = Pipeline.create( - task=self.pipeline_type, - model_path=self.model_stub, - internal_kv_cache=self.internal_kv_cache, - prompt_sequence_length=self.prompt_sequence_length, - sequence_length=self.sequence_length, - ) - return self.default_pipeline - # return a pipeline with the given kwargs - return Pipeline.create(**kwargs) - - @pytest.fixture - def setup( - self, - model_stub, - model_name, - uses_bos_token, + + +@pytest.fixture +def pipeline(): + return Pipeline.create( + task="text_generation", + model_path="hf:mgoin/TinyStories-1M-deepsparse", + engine_type="onnxruntime", + ) + + +@pytest.fixture +def prompt(): + return "Never gonna give you up, never gonna let you down" + + +def test_freeze_first_position(pipeline): + # Test whether we should be "freezing" the first token after + # the kv cache is full + assert not prepends_bos_token(pipeline.tokenizer) + + +def test_run_same_prompt_multiple_times(pipeline, prompt): + # Test the scenario, where the same prompt is run multiple times + # Every run should produce the same output + output_1 = pipeline(prompt, output_scores=True) + output_2 = pipeline(prompt, output_scores=True) + + assert output_1.generations[0].text == output_2.generations[0].text + assert numpy.allclose( + output_1.generations[0].score, + output_2.generations[0].score, + atol=1e-3, + ) + + +def test_run_multiple_prompts_in_parallel(pipeline, prompt): + # Test the scenario, where multiple prompts are run in parallel + # Same two prompts should produce the same output + + output = pipeline([prompt, prompt], output_scores=True) + + logits_0 = output.generations[0].score + sequence_0 = output.generations[0].text + + logits_1 = output.generations[1].score + sequence_1 = output.generations[1].text + + assert numpy.allclose(logits_0, logits_1, atol=1e-3) + assert sequence_0 == sequence_1 + + +def test_num_generated_predictions(pipeline, prompt): + # Test the scenario, where multiple predictions are generated + # from the same prompt + + output_sequences = pipeline(prompt, num_return_sequences=2) + + assert len(output_sequences.generations) == 1 + assert len(output_sequences.generations[0]) == 2 + + output_sequences = pipeline([prompt, prompt], num_return_sequences=2) + assert len(output_sequences.generations) == 2 + + for generation in output_sequences.generations: + assert len(generation) == 2 + + +def test_token_generation_deterministic(pipeline, prompt): + inference = pipeline(prompt, num_return_sequences=3, do_sample=False) + generations = inference.generations + # Output should be the same from one another + text_outputs = [x.text for x in generations[0]] + assert len(set(text_outputs)) == 1 + + +def test_token_generation_non_deterministic(pipeline, prompt): + + inference = pipeline(prompt, num_return_sequences=3, do_sample=True) + generations = inference.generations + # Output should be different from one another + text_outputs = [x.text for x in generations[0]] + assert len(set(text_outputs)) == 3 + + +def test_pipeline_for_ppl_eval(pipeline, prompt): + predictions = pipeline( prompt, - logits_max_diff_kv_cache_has_been_filled, - internal_kv_cache, - pipeline_type, - ): - self.num_tokens_generate = 216 - self.model_stub = model_stub - self.prompt = prompt - self.pipeline_type = pipeline_type - # create torch ground source - torch_source = TorchGroundTruthSource( - num_tokens_to_generate=self.num_tokens_generate, model_name=model_name - ) - torch_ground_truth = torch_source(self.prompt) - - # prompt length is expressed in number of prompt tokens - prompt_length = torch_ground_truth[1].shape[1] - - # sequence_length that assures that the KV cache will not be filled up - self.sequence_length = 2 * prompt_length + self.num_tokens_generate - # sequence_length that assures that the KV cache will be filled up - self.sequence_length_short = self.num_tokens_generate - - # prompt_sequence_length used for the multitoken prefill scenario - self.prompt_sequence_length = prompt_length // 2 - - # the maximum threshold for the difference between the logits - # when running a scenario where KV Cache buffer has been filled - self.logits_max_diff_kv_cache_has_been_filled = ( - logits_max_diff_kv_cache_has_been_filled - ) - self.internal_kv_cache = internal_kv_cache - - self.default_pipeline = None - - assert self.prompt_sequence_length < prompt_length, ( - "The prompt processing sequence length " - "must be smaller than the prompt length" - ) - - yield model_name, uses_bos_token, torch_ground_truth - - def test_freeze_first_position(self, setup): - # Test whether we should be "freezing" the first token after - # the kv cache is full - _, uses_bos_token, _ = setup - pipeline = self.get_pipeline() - assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token - - def test_ort_single_token_prefill(self, setup): - # Test the pipeline that uses ORT engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by single-token engine - # 2. The KV Cache is never filled up - # 3. KV Cache managed externally - - if self.internal_kv_cache: - pytest.skip( - "Cannot run ORT pipeline with the internal deepsparse cache enabled." - ) - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - engine_type="onnxruntime", - ) - pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - ) - - def test_ort_multi_token_prefill(self, setup): - # Test the pipeline that uses ORT engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is never filled up - # 3. KV Cache managed externally - - if self.internal_kv_cache: - pytest.skip( - "Cannot run ORT pipeline with the internal deepsparse cache enabled." - ) - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=self.prompt_sequence_length, - engine_type="onnxruntime", - ) - pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - ) - - def test_ort_generation_after_kv_cache_has_been_filled(self, setup): - # Test the pipeline that uses ORT engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is filled up (old entries are removed) - # 3. KV Cache managed externally - - if self.internal_kv_cache: - pytest.skip( - "Cannot run ORT pipeline with the internal deepsparse cache enabled." - ) - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length_short, - prompt_sequence_length=self.prompt_sequence_length, - engine_type="onnxruntime", - ) - pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( - "for this scenario, the kv cache should be full: " - "the total number of processed tokens should be " - "greater than the sequence length" - ) - - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 - ) - - def test_deepsparse_single_token_prefill(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by single-token engine - # 2. The KV Cache is never filled up - # 3. KV Cache managed externally or internally - - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - internal_kv_cache=self.internal_kv_cache, - ) - pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - ) - - def test_deepsparse_multi_token_prefill(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is never filled up - # 3. KV Cache managed externally or internally - - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=self.prompt_sequence_length, - internal_kv_cache=self.internal_kv_cache, - ) - pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] < self.sequence_length - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - ) - - def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the - # following scenario: - # 1. Prompt preprocessing is performed by multi-token engine - # 2. The KV Cache is filled up (old entries are removed) - # 3. KV Cache managed externally or internally - - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length_short, - prompt_sequence_length=self.prompt_sequence_length, - internal_kv_cache=self.internal_kv_cache, - ) - pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( - "for this scenario, the kv cache should be full: " - "the total number of processed tokens should be " - "greater than the sequence length" - ) - - self._test_output( - output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 - ) - - def test_run_same_prompt_multiple_times(self, setup): - # Test the scenario, where the same prompt is run multiple times - # Every run should produce the same output - pipeline = self.get_pipeline() - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - - output_1 = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - output_2 = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - assert output_1.generations[0].text == output_2.generations[0].text - assert numpy.allclose( - output_1.generations[0].score, - output_2.generations[0].score, - atol=_PRECISION, - ) - - def test_run_multiple_prompts_in_parallel(self, setup): - # Test the scenario, where multiple prompts are run in parallel - # Same two prompts should produce the same output - pipeline = self.get_pipeline() - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=[self.prompt, self.prompt], - generation_config=config, - include_prompt_logits=True, - ) - - logits_0 = output.generations[0].score - sequence_0 = output.generations[0].text - - logits_1 = output.generations[1].score - sequence_1 = output.generations[1].text - - assert numpy.allclose(logits_0, logits_1, atol=_PRECISION) - assert sequence_0 == sequence_1 - - def test_num_generated_predictions(self, setup): - # Test the scenario, where multiple predictions are generated - # from the same prompt - pipeline = self.get_pipeline() - - config = GenerationConfig( - num_return_sequences=2, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - ) - - output_sequences = pipeline(sequences=[self.prompt], generation_config=config) - assert len(output_sequences.generations) == 1 - assert len(output_sequences.generations[0]) == 2 - - output_sequences = pipeline( - sequences=[self.prompt, self.prompt], generation_config=config - ) - assert len(output_sequences.generations) == 2 - - for generation in output_sequences.generations: - assert len(generation) == 2 - - def test_token_generation_deterministic(self, setup): - pipeline_kwargs = { - "task": "text_generation", - "model_path": self.model_stub, - } - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=3, - do_sample=False, - ) - pipeline = self.get_pipeline(**pipeline_kwargs) - inference = pipeline(sequences=["hello?"], generation_config=config) - generations = inference.generations - text_outputs = [x.text for x in generations[0]] - assert len(set(text_outputs)) == 1 - - def test_token_generation_non_deterministic(self, setup): - pipeline_kwargs = { - "task": "text_generation", - "model_path": self.model_stub, - } - pipeline = self.get_pipeline(**pipeline_kwargs) - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=3, - do_sample=True, - ) - inference = pipeline(sequences=["hello?"], generation_config=config) - generations = inference.generations - # Output should be the same from one another - text_outputs = [x.text for x in generations[0]] - assert len(set(text_outputs)) == 3 - - def test_run_with_same_session_ids(self, setup): - # Test the scenario where the same session ids are used for multiple - # inference runs. There are two conditions that must be fulfilled: - # 1. The information regarding the prompt does not leak between sessions - # 2. Running two prompts one after another is identical to running - # a composition of those prompts i.e. - # generated_text = pipeline(prompt_1) - # generated_text_2 = pipeline(prompt_2) - # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) - - if self.pipeline_type not in ["chatbot", "chat"]: - pytest.skip("This test is only applicable to chatbot pipeline") - - prompt_1 = "This prompt is used for testing purposes. To this to make sure that" - prompt_2 = "still this prompt should not" - num_generated_tokens = 32 - - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=False, - ) - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=True, - ) - - def _test_run_with_same_session_ids( - self, - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill, - ): - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - prompt_sequence_length=self.prompt_sequence_length - if multi_token_prefill - else 1, - force_max_tokens=True, - internal_kv_cache=self.internal_kv_cache, - ) - - # make sure information does not leak between sessions - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_1", - session_id_2="test_2", - ) - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_3", - session_id_2="test_4", - ) - - @staticmethod - def _test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1, - session_id_2, - ): - - tokenizer = pipeline.tokenizer - config = GenerationConfig( - output_scores=True, max_length=num_generated_tokens, top_k=0, top_p=0.0 - ) - - # make sure that running two prompts one after another - # is identical to running a composition of those prompts - out_1_ = pipeline( - sequences=prompt_1, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, - ) - prompt_1_ = out_1_.generations[0].text - out_1 = pipeline( - sequences=prompt_2, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ - "past_key_values.0.key" - ] - - prompt_composition = tokenizer.decode( - tokenizer(prompt_1).input_ids - + tokenizer(prompt_1_).input_ids - + tokenizer(prompt_2).input_ids, - skip_special_tokens=True, - ) - out_2 = pipeline( - sequences=prompt_composition, - session_ids=session_id_2, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ - "past_key_values.0.key" - ] - if cache_state_1.shape[0]: - # if cache state is not empty, i.e. we are managing kv cache - # externally, make sure that the cache state is the same - numpy.allclose(cache_state_1, cache_state_2, atol=_PRECISION) - assert out_1.generations[0].text == out_2.generations[0].text - - def _test_output( - self, - output: "TextGenerationOutput", # noqa F821 - torch_ground_truth: Tuple[numpy.ndarray, ...], - max_logits_difference_threshold: Optional[float] = None, - run_cache_validation: bool = True, - ): - - ( - generated_logits, - prompt_logits, - prompt_kv_cache, - generated_text, - ) = torch_ground_truth - - # concatenate target prompt_logits and generated_logits and check - target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) - score = output.generations[0].score - - if max_logits_difference_threshold: - # if comparing the output from the model where - # the kv cache has been filled, we expect the - # maximum absolute difference between the logits - # to be less than the threshold - # (the threshold is established by running the - # ONNX model in ONNXRuntime) - assert abs(score - target_logits[0]).max() < max_logits_difference_threshold - else: - # otherwise, we expect the logits to be exactly the same - # as the target logits; the generated sequence should - # also be the same as the target sequence, and finally - # (if applicable) the kv cache should be the same as the - # target kv cache - - assert numpy.allclose(score, target_logits[0], atol=_PRECISION) - assert self.prompt + output.generations[0].text == generated_text - - if run_cache_validation: - # extract numpy arrays from cached_inputs - kv_cache_array = list(output.kv_cache_state[0].values()) - total_num_processed_tokens = output.total_num_processed_tokens[0] - self._test_kv_cache_state( - expected_cache=kv_cache_array, - target_cache=torch_ground_truth[2], - total_num_processed_tokens=total_num_processed_tokens, - ) - - @staticmethod - def _test_kv_cache_state( - expected_cache: List[numpy.ndarray], - target_cache: List[numpy.ndarray], - total_num_processed_tokens: int, - ): - for x, y in zip(expected_cache, target_cache): - start_index = total_num_processed_tokens - end_index = total_num_processed_tokens - y.shape[2] - # x is (in general) composed of three arrays: - # - padding cache entries (from 0 to -start_index) - # - prompt cache entries (from -start_index to -end_index) - # - generated cache entries (from -end_index to -1) - # as target_cache only pertains to prompt cache entries, we need to - # compare only the prompt cache entries in x with y - assert numpy.allclose( - x[:, :, -start_index:-end_index, :], y, atol=_PRECISION - ) - - def test_streaming_mode_returns_generator(self, setup): - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - ) - inputs = dict(prompt=self.prompt, streaming=True) - response_generator = pipeline(**inputs) - - assert inspect.isgenerator( - response_generator - ), "Pipeline should return a generator in streaming mode" - - assert all( - isinstance(response, pipeline.output_schema) - for response in response_generator - ), "Pipeline should return a generator of output_schema \ - objects in streaming mode" + output_scores=True, + return_input_tokens=True, + fixed_sequences_length=True, + include_prompt_logits=True, + max_length=1, + ) + assert hasattr(predictions, "generations") + assert hasattr(predictions, "input_tokens") + assert hasattr(predictions.generations[0], "score") + assert "input_ids" in predictions.input_tokens + assert "attention_mask" in predictions.input_tokens + + +def test_streaming_mode_returns_generator(pipeline, prompt): + response_generator = pipeline(prompt, streaming=True) + assert inspect.isgenerator( + response_generator + ), "Pipeline should return a generator in streaming mode" + + assert all( + isinstance(response, pipeline.output_schema) for response in response_generator + ), "Pipeline should return a generator of output_schema \ + objects in streaming mode"