From 21c6f0ddf5e349fd6a2891544d47848f57bd5772 Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Wed, 8 Nov 2023 10:26:48 -0500
Subject: [PATCH] Re-add unit test

---
 .../pipelines/test_text_generation.py         | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index c70c50a5ef..1a408fb92b 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -104,6 +104,29 @@ def test_token_generation_non_deterministic(pipeline, prompt):
     assert len(set(text_outputs)) == 3
 
 
+def test_pipeline_for_ppl_eval(self, ):
+    pipeline = self.get_pipeline(
+        task="text-generation",
+        model_path=self.model_stub,
+        sequence_length=self.sequence_length,
+        prompt_sequence_length=1,
+    )
+    inputs = dict(
+        prompt=self.prompt,
+        output_scores=True,
+        return_input_tokens=True,
+        fixed_sequences_length=True,
+        include_prompt_logits=True,
+        max_length=1,
+    )
+    predictions = pipeline(**inputs)
+    assert hasattr(predictions, "generations")
+    assert hasattr(predictions.generations[0], "score")
+    assert hasattr(predictions.generations[0], "input_tokens")
+    assert "input_ids" in predictions.generations[0].input_tokens
+    assert "attention_mask" in predictions.generations[0].input_tokens
+
+
 def test_streaming_mode_returns_generator(pipeline, prompt):
     response_generator = pipeline(prompt, streaming=True)
     assert inspect.isgenerator(
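
Note on the patch above: the re-added test keeps its old class-based shape (a
self parameter plus references to self.get_pipeline, self.model_stub,
self.sequence_length, and self.prompt), but it lands at module level in a file
whose other tests are plain functions taking the pipeline and prompt fixtures.
Collected as-is, pytest reports "fixture 'self' not found" and the test errors
before running. Below is a minimal sketch of a fixture-style rewrite under
stated assumptions: MODEL_STUB and SEQUENCE_LENGTH are hypothetical
placeholders for the old class attributes, and only the prompt fixture already
present in this file is assumed.

    from deepsparse import Pipeline

    # Hypothetical placeholders for the old class attributes
    # self.model_stub and self.sequence_length; substitute the real
    # values used by this test suite.
    MODEL_STUB = "zoo:..."  # placeholder model stub
    SEQUENCE_LENGTH = 128   # placeholder sequence length


    def test_pipeline_for_ppl_eval(prompt):
        # Build the pipeline the way the class-based test did:
        # prompt_sequence_length=1 processes the prompt one token at a
        # time, which yields per-token prompt logits for perplexity eval.
        pipeline = Pipeline.create(
            task="text-generation",
            model_path=MODEL_STUB,
            sequence_length=SEQUENCE_LENGTH,
            prompt_sequence_length=1,
        )
        predictions = pipeline(
            prompt=prompt,
            output_scores=True,
            return_input_tokens=True,
            fixed_sequences_length=True,
            include_prompt_logits=True,
            max_length=1,
        )
        assert hasattr(predictions, "generations")
        assert hasattr(predictions.generations[0], "score")
        assert hasattr(predictions.generations[0], "input_tokens")
        assert "input_ids" in predictions.generations[0].input_tokens
        assert "attention_mask" in predictions.generations[0].input_tokens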
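
For context, the flags exercised here (output_scores=True,
include_prompt_logits=True, return_input_tokens=True) exist so a caller can
recover per-token prompt logits and the matching token ids, which is what a
perplexity computation needs. A sketch of that computation follows; it
assumes, beyond anything the commit states, that generations[0].score is a
(seq_len, vocab_size) array of logits aligned with input_tokens["input_ids"],
and that padding positions are marked by attention_mask.

    import numpy as np


    def perplexity_from_prediction(predictions):
        # Assumed shapes (not guaranteed by the commit above):
        # score: (seq_len, vocab_size) logits
        # input_ids / attention_mask: (seq_len,) token ids and padding mask
        gen = predictions.generations[0]
        logits = np.asarray(gen.score)
        input_ids = np.asarray(gen.input_tokens["input_ids"]).reshape(-1)
        attention_mask = np.asarray(
            gen.input_tokens["attention_mask"]
        ).reshape(-1)

        # logits[i] scores the prediction of token i + 1, so shift by one
        shift_logits = logits[: len(input_ids) - 1]
        shift_labels = input_ids[1:]
        shift_mask = attention_mask[1:].astype(bool)

        # Numerically stable log-softmax in plain numpy
        m = shift_logits.max(axis=-1, keepdims=True)
        log_probs = shift_logits - m - np.log(
            np.exp(shift_logits - m).sum(axis=-1, keepdims=True)
        )
        token_log_probs = log_probs[
            np.arange(shift_labels.size), shift_labels
        ]

        # Perplexity = exp of the mean negative log-likelihood, averaged
        # over real (non-padding) tokens only
        return float(np.exp(-token_log_probs[shift_mask].mean()))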