From 71b0e74ea668de4ecada765c3c3ab6756d7270fc Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 17 Oct 2023 12:18:07 +0000 Subject: [PATCH 1/5] initial commit --- .../pipelines/configs/__init__.py | 13 + .../pipelines/configs/gpt_neo.yaml | 13 + .../transformers/pipelines/helpers.py | 49 +- .../transformers/pipelines/test_chat.py | 173 ++++- .../pipelines/test_text_generation.py | 645 +++++++----------- 5 files changed, 453 insertions(+), 440 deletions(-) create mode 100644 tests/deepsparse/transformers/pipelines/configs/__init__.py create mode 100644 tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml diff --git a/tests/deepsparse/transformers/pipelines/configs/__init__.py b/tests/deepsparse/transformers/pipelines/configs/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/configs/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml b/tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml new file mode 100644 index 0000000000..df209b5bee --- /dev/null +++ b/tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml @@ -0,0 +1,13 @@ +model_path: "hf:mgoin/TinyStories-1M-deepsparse" +model_name: "roneneldan/TinyStories-1M" +pipeline_type: ["text-generation", "chat"] +num_tokens_generate: 128 +prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio\n Some cat was layin' down some rock 'n' roll\n \"Lotta soul,\" he said\n Then the loud sound did seem to fade\n Came back like a slow voice on a wave of phase\n That weren't no DJ, that was hazy cosmic jive" +has_bos_token: False +logits_threshold: 24.7 +precision: 0.001 +cache_management_type: + - "internal" + - "external" +run_helper_tests: True +cadence: "commit" \ No newline at end of file diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/helpers.py index 0bb962a8e3..e22f5d23ff 100644 --- a/tests/deepsparse/transformers/pipelines/helpers.py +++ b/tests/deepsparse/transformers/pipelines/helpers.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
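For reference, the gpt_neo.yaml config above is read with yaml.safe_load and gated on the CADENCE environment variable before any test runs. A minimal sketch of that flow, assuming the file path and key names shown in the config (the final print is purely illustrative):

import os

import yaml

config_path = "tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml"
with open(config_path) as f:
    config = yaml.safe_load(f)

# Only run the suite when the requested cadence matches the CADENCE
# environment variable (defaulting to "commit"), mirroring parse_params() below.
if os.environ.get("CADENCE", "commit") == config["cadence"]:
    print(config["model_path"], config["pipeline_type"], config["num_tokens_generate"])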
-from typing import List, Tuple +import functools +import os +from typing import Dict, List, Optional, Tuple, Union import numpy +import yaml from transformers import AutoModelForCausalLM, AutoTokenizer +import pytest + class TorchGroundTruthSource: """ @@ -82,3 +87,45 @@ def _create_tokenizer(model_name): tokenizer.pad_token = tokenizer.eos_token return tokenizer + + +def parse_params(config_path: str) -> Tuple[Optional[Dict], Optional[str]]: + # parses the config file provided + assert os.path.isfile(config_path), f"config_path {config_path} is not a file" + # reads the yaml file + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + cadence = os.environ.get("CADENCE", "commit") + if cadence == config["cadence"]: + return config, None + return None, "Skipping test for cadence: {}".format(config["cadence"]) + + +def validate_cache_management_type( + internal_kv_cache, cache_management_type: Union[str, List[str]] +) -> bool: + if internal_kv_cache and "internal" not in cache_management_type: + pytest.skip( + "The tests for running the pipeline with " + "internal kv cache management are disabled." + ) + if not internal_kv_cache and "external" not in cache_management_type: + pytest.skip( + "The tests for running the pipeline with " + "external kv cache management are disabled." + ) + return internal_kv_cache + + +def helper_test(test_method): + @functools.wraps(test_method) + def wrapper(self, setup): + if not self.run_helper_tests: + raise pytest.skip( + "Skipping the helper test. Set run_helper_tests to True to run it." + ) + + return test_method(self, setup) + + return wrapper diff --git a/tests/deepsparse/transformers/pipelines/test_chat.py b/tests/deepsparse/transformers/pipelines/test_chat.py index 2a6b5d1ebf..4da1c66726 100644 --- a/tests/deepsparse/transformers/pipelines/test_chat.py +++ b/tests/deepsparse/transformers/pipelines/test_chat.py @@ -12,34 +12,157 @@ # See the License for the specific language governing permissions and # limitations under the License. 
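A hedged sketch of how the helpers added above (parse_params, validate_cache_management_type, and the helper_test decorator) are meant to compose inside a config-driven test class; the class name and the example assertion are illustrative, not part of the patch:

import pytest

from tests.deepsparse.transformers.pipelines.helpers import (
    helper_test,
    parse_params,
    validate_cache_management_type,
)


class ExampleConfigDrivenSuite:  # illustrative name
    @pytest.fixture
    def setup(self, config, internal_kv_cache):
        params, skip_reason = parse_params(config)
        if params is None:
            pytest.skip(skip_reason)
        # expose config values (including run_helper_tests) as attributes
        for key, value in params.items():
            setattr(self, key, value)
        # skip when the requested kv cache management type is not under test
        self.internal_kv_cache = validate_cache_management_type(
            internal_kv_cache, self.cache_management_type
        )

    @helper_test  # skipped unless run_helper_tests is True in the config
    def test_off_hot_path_behaviour(self, setup):
        assert self.num_tokens_generate > 0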
+import numpy as np +from transformers import GenerationConfig + import pytest -from deepsparse import Pipeline - - -@pytest.mark.parametrize( - "pipeline_kwargs", - [ - dict( - model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - "huggingface/bigpython_bigquery_thepile/base-none", - engine_type="onnxruntime", - ), - ], +from tests.deepsparse.transformers.pipelines.helpers import helper_test +from tests.deepsparse.transformers.pipelines.test_text_generation import ( + TestTextGenerationPipeline, ) -@pytest.mark.skip(reason="too heavy for now to run in gha") -def test_chat_pipeline_session_manager(pipeline_kwargs): - chat_pipeline = Pipeline.create(task="chat", **pipeline_kwargs) - with chat_pipeline.session(): - output_1 = chat_pipeline( - prompt="first", generation_config=dict(max_new_tokens=1) + +@pytest.fixture +def config(request): + return request.param + + +class TestChatPipeline(TestTextGenerationPipeline): + @pytest.fixture + def pipeline_type(self): + return "chat" + + @helper_test + def test_chat_pipeline_session_manager(self, setup): + pipeline = self.get_pipeline() + + with pipeline.session(): + output_1 = pipeline( + prompt="first", generation_config=dict(max_new_tokens=1) + ) + output_2 = pipeline( + prompt="second", generation_config=dict(max_new_tokens=1) + ) + # assert inferences in the same context share a session id + assert output_1.session_ids == output_2.session_ids + + # test that follow-up inference has a different session id + output_3 = pipeline(prompt="third", generation_config=dict(max_new_tokens=1)) + assert output_3.session_ids != output_1.session_ids + + @helper_test + def test_run_with_same_session_ids(self, setup): + # Test the scenario where the same session ids are used for multiple + # inference runs. There are two conditions that must be fulfilled: + # 1. The information regarding the prompt does not leak between sessions + # 2. Running two prompts one after another is identical to running + # a composition of those prompts i.e. + # generated_text = pipeline(prompt_1) + # generated_text_2 = pipeline(prompt_2) + # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) + + prompt_1 = "This prompt is used for testing purposes. 
To this to make sure that" + prompt_2 = "still this prompt should not" + num_generated_tokens = 32 + + self._test_run_with_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + multi_token_prefill=False, + ) + self._test_run_with_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + multi_token_prefill=True, + ) + + def _test_run_with_same_session_ids( + self, + prompt_1, + prompt_2, + num_generated_tokens, + multi_token_prefill, + ): + pipeline = self.get_pipeline( + prompt_sequence_length=self.prompt_sequence_length + if multi_token_prefill + else 1, + ) + + # make sure information does not leak between sessions + self._test_composition_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1="test_1", + session_id_2="test_2", + ) + + self._test_composition_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1="test_3", + session_id_2="test_4", + ) + + def _test_composition_same_session_ids( + self, + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1, + session_id_2, + ): + + tokenizer = pipeline.tokenizer + config = GenerationConfig( + output_scores=True, max_length=num_generated_tokens, top_k=0, top_p=0.0 + ) + + # make sure that running two prompts one after another + # is identical to running a composition of those prompts + out_1_ = pipeline( + sequences=prompt_1, + force_max_tokens=True, + session_ids=session_id_1, + generation_config=config, + include_prompt_logits=True, ) - output_2 = chat_pipeline( - prompt="second", generation_config=dict(max_new_tokens=1) + prompt_1_ = out_1_.generations[0].text + out_1 = pipeline( + sequences=prompt_2, + force_max_tokens=True, + session_ids=session_id_1, + generation_config=config, + include_prompt_logits=True, ) - # assert inferences in the same context share a session id - assert output_1.session_ids == output_2.session_ids + cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ + "past_key_values.0.key" + ] - # test that follow-up inference has a different session id - output_3 = chat_pipeline(prompt="third", generation_config=dict(max_new_tokens=1)) - assert output_3.session_ids != output_1.session_ids + prompt_composition = tokenizer.decode( + tokenizer(prompt_1).input_ids + + tokenizer(prompt_1_).input_ids + + tokenizer(prompt_2).input_ids, + skip_special_tokens=True, + ) + out_2 = pipeline( + sequences=prompt_composition, + session_ids=session_id_2, + generation_config=config, + include_prompt_logits=True, + ) + cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ + "past_key_values.0.key" + ] + if cache_state_1.shape[0]: + # if cache state is not empty, i.e. we are managing kv cache + # externally, make sure that the cache state is the same + np.allclose(cache_state_1, cache_state_2, atol=self.precision) + assert out_1.generations[0].text == out_2.generations[0].text diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 5298c2f1dd..b8a3e4ff4b 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -11,7 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
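The session-composition check implemented by _test_composition_same_session_ids above reduces to the property sketched below. Plain string concatenation stands in for the token-level composition done in the test, and the pipeline construction simply reuses the task and model path from the gpt_neo.yaml config; treat it as a sketch rather than the exact test code:

from deepsparse import Pipeline
from transformers import GenerationConfig

pipeline = Pipeline.create(task="chat", model_path="hf:mgoin/TinyStories-1M-deepsparse")
config = GenerationConfig(max_length=32)

prompt_1 = "This prompt is used for testing purposes. To this to make sure that"
prompt_2 = "still this prompt should not"

# Running prompt_1 then prompt_2 inside one session should be equivalent to
# running the composed prompt in a fresh session.
generated_1 = pipeline(
    sequences=prompt_1, session_ids="session_a", generation_config=config
).generations[0].text
out_sequential = pipeline(
    sequences=prompt_2, session_ids="session_a", generation_config=config
)

composed_prompt = prompt_1 + generated_1 + prompt_2
out_composed = pipeline(
    sequences=composed_prompt, session_ids="session_b", generation_config=config
)

assert out_sequential.generations[0].text == out_composed.generations[0].text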
- +""" +A sample config file requires the following arguments: + model_path: The path to the model to be tested + (sparsezoo stub/hf model path/local_path) + model_name: The name of the hugging face model + (to generate ground truth info) + pipeline_type: The type of the pipeline to be tested + (e.g. text-generation) + num_tokens_generate: The number of tokens to generate + prompt: The prompt to use for testing + has_bos_token: Whether the model has a bos token + logits_threshold: The treshold for the max difference between the + actual and the expected logits in the situations + where they will not be able to match the ground + truth (e.g. when the DeepSparse pipeline is + running after the KV cache has been filled up). + This value is established + empirically for each combination of + prompt/pipeline/num_generated tokens. + precision: The precision for the logits/kv_cache entries comparison + cache_management_type: The type of cache management to be tested. + The available options are: "internal" and "external". + run_helper_tests: Whether to run the helper test for the pipeline. Helper tests + check functionalities of the pipeline that are not directly + on the hot path. + cadence: The cadence of the tests. The available options are: + "nightly" and "commit". By default, only the tests that have cadence + "commit" will be run in GHA. +""" import inspect from typing import List, Optional, Tuple @@ -20,150 +48,149 @@ import pytest from deepsparse import Pipeline +from deepsparse.transformers.pipelines.text_generation import TextGenerationOutput from deepsparse.transformers.utils.helpers import prepends_bos_token -from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource +from tests.deepsparse.transformers.pipelines.helpers import ( + TorchGroundTruthSource, + helper_test, + parse_params, + validate_cache_management_type, +) -_PRECISION = 1e-3 +# the user can specify the config file to be used for the tests +# TODO: add more configs +# TODO: add explanation +AVAILABLE_CONFIGS = [ + "tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml", + # "tests/deepsparse/transformers/pipelines/configs/text_generation_opt.yaml", + # "tests/deepsparse/transformers/pipelines/configs/text_generation_codegen.yaml", +] -NATURAL_LANGUAGE_PROMPT = """ -Didn't know what time it was, the lights were low -I leaned back on my radio -Some cat was layin' down some rock 'n' roll -"Lotta soul," he said -Then the loud sound did seem to fade -Came back like a slow voice on a wave of phase -That weren't no DJ, that was hazy cosmic jive -""" -CODE_LANGUAGE_PROMPT = """ -def Fibonacci(n): - # Check if input is 0 then it will - # print incorrect input - if n < 0: - print("Incorrect input") - # Check if n is 0 - # then it will return 0 - elif n == 0: - return 0 -""" +@pytest.fixture +def config(request): + return request.param +@pytest.mark.parametrize("config", AVAILABLE_CONFIGS, indirect=["config"]) @pytest.mark.parametrize( "internal_kv_cache", - [ - True, - False, - ], -) -@pytest.mark.parametrize( - "pipeline_type", - ["text_generation", "chat"], -) -@pytest.mark.parametrize( - "model_stub, " - "model_name, " - "uses_bos_token, " - "prompt, " - "logits_max_diff_kv_cache_has_been_filled", - [ - ( - "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - "huggingface/bigpython_bigquery_thepile/base-none", - "salesforce/codegen-350m-mono", - False, - CODE_LANGUAGE_PROMPT, - 13, - ), - ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" - "opt_pretrain/base-none", - 
"facebook/opt-1.3b", - True, - NATURAL_LANGUAGE_PROMPT, - 3.9, - ), - ], - scope="class", + [True, False], ) -@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") class TestTextGenerationPipeline: """ This test suite is meant to test the main scenarios of the text generation pipeline. """ - def get_pipeline(self, **kwargs): + @pytest.fixture + def pipeline_type(self): + return "text-generation" + + def get_pipeline(self, **kwargs) -> Pipeline: + """ + If no kwargs provided, returns the cached "default" + pipeline that is used for most of the tests. + Otherwise, returns a pipeline with the given kwargs + (the default pipeline kwargs are updated with the + user-provided kwargs) + + :param kwargs: the optional kwargs to be used to + create the pipeline (if not provided, the cached + "default" pipeline is returned) + :return: the appropriate pipeline + """ if not kwargs: - # return the default pipeline - if self.default_pipeline: - return self.default_pipeline - else: - self.default_pipeline = Pipeline.create( - task=self.pipeline_type, - model_path=self.model_stub, - internal_kv_cache=self.internal_kv_cache, - prompt_sequence_length=self.prompt_sequence_length, - sequence_length=self.sequence_length, - ) - return self.default_pipeline - # return a pipeline with the given kwargs - return Pipeline.create(**kwargs) + if self.default_pipeline is None: + self.default_pipeline = Pipeline.create(**self.default_pipeline_kwargs) + return self.default_pipeline + + # return a pipeline with the updated default kwargs + updated_kwargs = self.default_pipeline_kwargs.copy() + updated_kwargs.update(kwargs) + return Pipeline.create(**updated_kwargs) + + def run_pipeline(self, pipeline: Pipeline, **kwargs) -> TextGenerationOutput: + """ + Run the pipeline and return the output + + :param pipeline: the pipeline to be run + :param kwargs: the optional kwargs to be used to + run the pipeline + :return: the pipeline output + """ + sequences = kwargs.get("sequences", self.prompt) + num_return_sequences = kwargs.get("num_return_sequences", 1) + do_sample = kwargs.get("do_sample", False) + streaming = kwargs.get("streaming", False) + + config = GenerationConfig( + output_scores=True, + max_length=self.num_tokens_generate, + top_k=0, + top_p=0.0, + num_return_sequences=num_return_sequences, + do_sample=do_sample, + ) + return pipeline( + sequences=sequences, + force_max_tokens=True, + include_prompt_logits=True, + generation_config=config, + streaming=streaming, + ) @pytest.fixture - def setup( - self, - model_stub, - model_name, - uses_bos_token, - prompt, - logits_max_diff_kv_cache_has_been_filled, - internal_kv_cache, - pipeline_type, - ): - self.num_tokens_generate = 216 - self.model_stub = model_stub - self.prompt = prompt - self.pipeline_type = pipeline_type + def setup(self, config, internal_kv_cache, pipeline_type): + params_dict, skip_reason = parse_params(config) + if params_dict is None: + # skip the test if the config file is not available + pytest.skip(skip_reason) + # set the params_dict as the class attributes + for key, value in params_dict.items(): + setattr(self, key, value) + # check whether the internal kv cache is supported for testing + # (skip if not supported) + self.internal_kv_cache: bool = validate_cache_management_type( + internal_kv_cache, self.cache_management_type + ) + # check whether the pipeline_type is supported for testing + # (skip if not supported) + if pipeline_type not in self.pipeline_type: + pytest.skip( + f"Pipeline type: 
{self.pipeline_type} " + f"does not match the current type: {pipeline_type}" + ) + # create torch ground source torch_source = TorchGroundTruthSource( - num_tokens_to_generate=self.num_tokens_generate, model_name=model_name + num_tokens_to_generate=self.num_tokens_generate, model_name=self.model_name ) - torch_ground_truth = torch_source(self.prompt) - - # prompt length is expressed in number of prompt tokens - prompt_length = torch_ground_truth[1].shape[1] + # create torch ground truth + self.torch_ground_truth = torch_source(self.prompt) + prompt_length = self.torch_ground_truth[1].shape[1] # sequence_length that assures that the KV cache will not be filled up self.sequence_length = 2 * prompt_length + self.num_tokens_generate # sequence_length that assures that the KV cache will be filled up self.sequence_length_short = self.num_tokens_generate - # prompt_sequence_length used for the multitoken prefill scenario - self.prompt_sequence_length = prompt_length // 2 - - # the maximum threshold for the difference between the logits - # when running a scenario where KV Cache buffer has been filled - self.logits_max_diff_kv_cache_has_been_filled = ( - logits_max_diff_kv_cache_has_been_filled - ) - self.internal_kv_cache = internal_kv_cache - - self.default_pipeline = None - + # prompt_sequence_length used for the multi-token prefill scenario + self.prompt_sequence_length = prompt_length // 4 assert self.prompt_sequence_length < prompt_length, ( "The prompt processing sequence length " "must be smaller than the prompt length" ) - yield model_name, uses_bos_token, torch_ground_truth - - def test_freeze_first_position(self, setup): - # Test whether we should be "freezing" the first token after - # the kv cache is full - _, uses_bos_token, _ = setup - pipeline = self.get_pipeline() - assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token + self.default_pipeline_kwargs = dict( + task=pipeline_type, + model_path=self.model_path, + internal_kv_cache=self.internal_kv_cache, + prompt_sequence_length=self.prompt_sequence_length, + sequence_length=self.sequence_length, + ) + self.default_pipeline = None def test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. The test covers the @@ -176,27 +203,21 @@ def test_ort_single_token_prefill(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, prompt_sequence_length=1, engine_type="onnxruntime", ) pipeline._debug = True + output = self.run_pipeline(pipeline) - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config + assert output.total_num_processed_tokens[0] < self.sequence_length, ( + "The total number of processed tokens must be smaller than the " + "sequence length" ) - assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - torch_ground_truth=torch_ground_truth, + torch_ground_truth=self.torch_ground_truth, ) def test_ort_multi_token_prefill(self, setup): @@ -210,26 +231,16 @@ def test_ort_multi_token_prefill(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." 
) - _, _, torch_ground_truth = setup pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=self.prompt_sequence_length, engine_type="onnxruntime", ) pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) + output = self.run_pipeline(pipeline) assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - torch_ground_truth=torch_ground_truth, + torch_ground_truth=self.torch_ground_truth, ) def test_ort_generation_after_kv_cache_has_been_filled(self, setup): @@ -243,22 +254,13 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, sequence_length=self.sequence_length_short, - prompt_sequence_length=self.prompt_sequence_length, engine_type="onnxruntime", ) pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) + output = self.run_pipeline(pipeline) assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " @@ -268,8 +270,8 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - torch_ground_truth=torch_ground_truth, - max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 + torch_ground_truth=self.torch_ground_truth, + logits_threshold=self.logits_threshold, ) def test_deepsparse_single_token_prefill(self, setup): @@ -279,27 +281,21 @@ def test_deepsparse_single_token_prefill(self, setup): # 2. The KV Cache is never filled up # 3. KV Cache managed externally or internally - _, _, torch_ground_truth = setup pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, prompt_sequence_length=1, - internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) + output = self.run_pipeline(pipeline) - assert output.total_num_processed_tokens[0] < self.sequence_length + assert output.total_num_processed_tokens[0] < self.sequence_length, ( + "The total number of processed tokens must be smaller than the " + "sequence length" + ) self._test_output( output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, + torch_ground_truth=self.torch_ground_truth, + # disable kv cache validation if using internal kv cache + run_kv_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_multi_token_prefill(self, setup): @@ -307,54 +303,33 @@ def test_deepsparse_multi_token_prefill(self, setup): # following scenario: # 1. Prompt preprocessing is performed by multi-token engine # 2. The KV Cache is never filled up - # 3. KV Cache managed externally or internally + # 3. 
KV Cache managed internally or externally - _, _, torch_ground_truth = setup - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=self.prompt_sequence_length, - internal_kv_cache=self.internal_kv_cache, - ) + pipeline = self.get_pipeline() pipeline._debug = True - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) + output = self.run_pipeline(pipeline) assert output.total_num_processed_tokens[0] < self.sequence_length + self._test_output( output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, + torch_ground_truth=self.torch_ground_truth, + # disable kv cache validation if using internal kv cache + run_kv_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): - # Test the pipeline that uses deepsparse engine. The test covers the + # Test the deepsparse that uses deepsparse engine. The test covers the # following scenario: # 1. Prompt preprocessing is performed by multi-token engine # 2. The KV Cache is filled up (old entries are removed) - # 3. KV Cache managed externally or internally + # 3. KV Cache managed internally or externally - _, _, torch_ground_truth = setup pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, sequence_length=self.sequence_length_short, - prompt_sequence_length=self.prompt_sequence_length, - internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) + output = self.run_pipeline(pipeline) assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " @@ -364,47 +339,39 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, - max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 + torch_ground_truth=self.torch_ground_truth, + logits_threshold=self.logits_threshold, + run_kv_cache_validation=not self.internal_kv_cache, ) + @helper_test + def test_freeze_first_position(self, setup): + # Test whether we should be "freezing" the first token after + # the kv cache is full + pipeline = self.get_pipeline() + assert prepends_bos_token(pipeline.tokenizer) == self.has_bos_token + + @helper_test def test_run_same_prompt_multiple_times(self, setup): # Test the scenario, where the same prompt is run multiple times # Every run should produce the same output pipeline = self.get_pipeline() - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - - output_1 = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) - - output_2 = pipeline( - sequences=self.prompt, include_prompt_logits=True, generation_config=config - ) + output_1 = self.run_pipeline(pipeline) + output_2 = self.run_pipeline(pipeline) assert output_1.generations[0].text == output_2.generations[0].text assert numpy.allclose( output_1.generations[0].score, 
output_2.generations[0].score, - atol=_PRECISION, + atol=self.precision, ) + @helper_test def test_run_multiple_prompts_in_parallel(self, setup): # Test the scenario, where multiple prompts are run in parallel # Same two prompts should produce the same output pipeline = self.get_pipeline() - - config = GenerationConfig( - output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 - ) - output = pipeline( - sequences=[self.prompt, self.prompt], - generation_config=config, - include_prompt_logits=True, - ) + output = self.run_pipeline(pipeline, sequences=[self.prompt, self.prompt]) logits_0 = output.generations[0].score sequence_0 = output.generations[0].text @@ -412,200 +379,68 @@ def test_run_multiple_prompts_in_parallel(self, setup): logits_1 = output.generations[1].score sequence_1 = output.generations[1].text - assert numpy.allclose(logits_0, logits_1, atol=_PRECISION) + assert numpy.allclose(logits_0, logits_1, atol=self.precision) assert sequence_0 == sequence_1 + @helper_test def test_num_generated_predictions(self, setup): # Test the scenario, where multiple predictions are generated # from the same prompt pipeline = self.get_pipeline() - config = GenerationConfig( - num_return_sequences=2, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, + output_sequences = self.run_pipeline( + pipeline, sequences=[self.prompt], num_return_sequences=2 ) - - output_sequences = pipeline(sequences=[self.prompt], generation_config=config) assert len(output_sequences.generations) == 1 assert len(output_sequences.generations[0]) == 2 - output_sequences = pipeline( - sequences=[self.prompt, self.prompt], generation_config=config + output_sequences = self.run_pipeline( + pipeline, sequences=[self.prompt, self.prompt], num_return_sequences=2 ) assert len(output_sequences.generations) == 2 for generation in output_sequences.generations: assert len(generation) == 2 + @helper_test def test_token_generation_deterministic(self, setup): - pipeline_kwargs = { - "task": "text_generation", - "model_path": self.model_stub, - } - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=3, - do_sample=False, - ) - pipeline = self.get_pipeline(**pipeline_kwargs) - inference = pipeline(sequences=["hello?"], generation_config=config) + pipeline = self.get_pipeline() + inference = self.run_pipeline(pipeline, num_return_sequences=3, do_sample=False) generations = inference.generations + # Output should be the same from one another text_outputs = [x.text for x in generations[0]] assert len(set(text_outputs)) == 1 + @helper_test def test_token_generation_non_deterministic(self, setup): - pipeline_kwargs = { - "task": "text_generation", - "model_path": self.model_stub, - } - pipeline = self.get_pipeline(**pipeline_kwargs) - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=3, - do_sample=True, - ) - inference = pipeline(sequences=["hello?"], generation_config=config) + pipeline = self.get_pipeline() + inference = self.run_pipeline(pipeline, num_return_sequences=3, do_sample=True) generations = inference.generations - # Output should be the same from one another + # Output should be different from one another text_outputs = [x.text for x in generations[0]] assert len(set(text_outputs)) == 3 - def test_run_with_same_session_ids(self, setup): - # Test the scenario where the same session ids are used for multiple - # inference runs. 
There are two conditions that must be fulfilled: - # 1. The information regarding the prompt does not leak between sessions - # 2. Running two prompts one after another is identical to running - # a composition of those prompts i.e. - # generated_text = pipeline(prompt_1) - # generated_text_2 = pipeline(prompt_2) - # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) - - if self.pipeline_type not in ["chatbot", "chat"]: - pytest.skip("This test is only applicable to chatbot pipeline") - - prompt_1 = "This prompt is used for testing purposes. To this to make sure that" - prompt_2 = "still this prompt should not" - num_generated_tokens = 32 - - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=False, - ) - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=True, - ) - - def _test_run_with_same_session_ids( - self, - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill, - ): - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - prompt_sequence_length=self.prompt_sequence_length - if multi_token_prefill - else 1, - force_max_tokens=True, - internal_kv_cache=self.internal_kv_cache, - ) - - # make sure information does not leak between sessions - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_1", - session_id_2="test_2", - ) - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_3", - session_id_2="test_4", - ) - - @staticmethod - def _test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1, - session_id_2, - ): + @helper_test + def test_streaming_mode_returns_generator(self, setup): + pipeline = self.get_pipeline(prompt_sequence_length=1) + response_generator = self.run_pipeline(pipeline, streaming=True) - tokenizer = pipeline.tokenizer - config = GenerationConfig( - output_scores=True, max_length=num_generated_tokens, top_k=0, top_p=0.0 - ) + assert inspect.isgenerator( + response_generator + ), "Pipeline should return a generator in streaming mode" - # make sure that running two prompts one after another - # is identical to running a composition of those prompts - out_1_ = pipeline( - sequences=prompt_1, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, - ) - prompt_1_ = out_1_.generations[0].text - out_1 = pipeline( - sequences=prompt_2, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ - "past_key_values.0.key" - ] - - prompt_composition = tokenizer.decode( - tokenizer(prompt_1).input_ids - + tokenizer(prompt_1_).input_ids - + tokenizer(prompt_2).input_ids, - skip_special_tokens=True, - ) - out_2 = pipeline( - sequences=prompt_composition, - session_ids=session_id_2, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ - "past_key_values.0.key" - ] - if cache_state_1.shape[0]: - # if cache state is not empty, i.e. 
we are managing kv cache - # externally, make sure that the cache state is the same - numpy.allclose(cache_state_1, cache_state_2, atol=_PRECISION) - assert out_1.generations[0].text == out_2.generations[0].text + assert all( + isinstance(response, pipeline.output_schema) + for response in response_generator + ), "Pipeline should return a generator of output_schema \ + objects in streaming mode" def _test_output( self, - output: "TextGenerationOutput", # noqa F821 + output: TextGenerationOutput, torch_ground_truth: Tuple[numpy.ndarray, ...], - max_logits_difference_threshold: Optional[float] = None, - run_cache_validation: bool = True, + logits_threshold: Optional[float] = None, + run_kv_cache_validation: bool = True, ): ( @@ -615,40 +450,42 @@ def _test_output( generated_text, ) = torch_ground_truth - # concatenate target prompt_logits and generated_logits and check + # concatenate target prompt_logits and generated_logits target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) + # get the logits of the generated sequence score = output.generations[0].score - if max_logits_difference_threshold: + if logits_threshold: # if comparing the output from the model where # the kv cache has been filled, we expect the # maximum absolute difference between the logits # to be less than the threshold # (the threshold is established by running the # ONNX model in ONNXRuntime) - assert abs(score - target_logits[0]).max() < max_logits_difference_threshold + target_logits = target_logits[0] + if target_logits.shape[0] < score.shape[0]: + score = score[: target_logits.shape[0], :] + assert abs(score - target_logits).max() < logits_threshold else: # otherwise, we expect the logits to be exactly the same # as the target logits; the generated sequence should - # also be the same as the target sequence, and finally - # (if applicable) the kv cache should be the same as the - # target kv cache - - assert numpy.allclose(score, target_logits[0], atol=_PRECISION) + # also be the same as the target sequence + assert numpy.allclose(score, target_logits[0], atol=self.precision) assert self.prompt + output.generations[0].text == generated_text - if run_cache_validation: - # extract numpy arrays from cached_inputs - kv_cache_array = list(output.kv_cache_state[0].values()) + if hasattr(output, "kv_cache_state") and run_kv_cache_validation: + # (if applicable) the kv cache should be the same as the + # target kv cache + expected_cache = list(output.kv_cache_state[0].values()) total_num_processed_tokens = output.total_num_processed_tokens[0] self._test_kv_cache_state( - expected_cache=kv_cache_array, - target_cache=torch_ground_truth[2], + expected_cache=expected_cache, + target_cache=prompt_kv_cache, total_num_processed_tokens=total_num_processed_tokens, ) - @staticmethod def _test_kv_cache_state( + self, expected_cache: List[numpy.ndarray], target_cache: List[numpy.ndarray], total_num_processed_tokens: int, @@ -663,25 +500,5 @@ def _test_kv_cache_state( # as target_cache only pertains to prompt cache entries, we need to # compare only the prompt cache entries in x with y assert numpy.allclose( - x[:, :, -start_index:-end_index, :], y, atol=_PRECISION + x[:, :, -start_index:-end_index, :], y, atol=self.precision ) - - def test_streaming_mode_returns_generator(self, setup): - pipeline = self.get_pipeline( - task=self.pipeline_type, - model_path=self.model_stub, - sequence_length=self.sequence_length, - prompt_sequence_length=1, - ) - inputs = dict(prompt=self.prompt, streaming=True) - 
response_generator = pipeline(**inputs) - - assert inspect.isgenerator( - response_generator - ), "Pipeline should return a generator in streaming mode" - - assert all( - isinstance(response, pipeline.output_schema) - for response in response_generator - ), "Pipeline should return a generator of output_schema \ - objects in streaming mode" From e9f7e880185ca8547bdc1e363694d1e1a251b080 Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 18 Oct 2023 06:35:50 +0000 Subject: [PATCH 2/5] initial commit --- .../transformers/engines/nl_decoder_engine.py | 35 ++++++++++++------- .../transformers/pipelines/text_generation.py | 5 ++- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 9c93616053..dabf520ff3 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -154,18 +154,29 @@ def run( :return: The output of the engine """ - if bool(kv_cache.engine_internal_cache): - # conventionally, before dispatching - # inputs to the engine, we validate them - # if val_inp=True. However, in this case - # we want to pass the empty kv cache inputs - # (batch_size=0) to the engine. Therefore, - # we skip the validation - return self.engine._eng_net.execute_list_out( - inputs, kv_cache.engine_internal_cache - ) - # run the engine without the LIB.kv_cache object - return self.engine.run(inputs, val_inp) + if kv_cache is not None: + # run the engine assuming kv cache support + if bool(kv_cache.engine_internal_cache): + # run the engine assuming internal kv cache + # management. In this case the LIB.kv_cache + # class object will be passed to the engine + # call as well + # conventionally, before dispatching + # inputs to the engine, we validate them + # if val_inp=True. However, in this case + # we want to pass the empty kv cache inputs + # (batch_size=0) to the engine. Therefore, + # we skip the validation + return self.engine._eng_net.execute_list_out( + inputs, kv_cache.engine_internal_cache + ) + else: + # run the engine assuming external kv cache + # management. 
+ return self.engine.run(inputs, val_inp, kv_cache) + else: + # run the engine without the kv cache support + return self.engine.run(inputs, val_inp) def __call__( self, diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 06c1d9750c..9577b482f3 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -683,7 +683,10 @@ def engine_forward( ) for prompt_logit in prompt_logits: token_generator.generate(prompt_logit) - return numpy.array([self.tokens]), prompt_logits + yield numpy.array([token_generator.tokens]), prompt_logits, [ + FinishReason.LENGTH + ] + return else: # run the prompt through From 0431a7fb57211d68fd16f544784c08010578d131 Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 18 Oct 2023 06:38:21 +0000 Subject: [PATCH 3/5] revert changes --- .../transformers/pipelines/helpers.py | 49 +- .../transformers/pipelines/test_chat.py | 173 +---- .../pipelines/test_text_generation.py | 645 +++++++++++------- 3 files changed, 440 insertions(+), 427 deletions(-) diff --git a/tests/deepsparse/transformers/pipelines/helpers.py b/tests/deepsparse/transformers/pipelines/helpers.py index e22f5d23ff..0bb962a8e3 100644 --- a/tests/deepsparse/transformers/pipelines/helpers.py +++ b/tests/deepsparse/transformers/pipelines/helpers.py @@ -12,16 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -import os -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Tuple import numpy -import yaml from transformers import AutoModelForCausalLM, AutoTokenizer -import pytest - class TorchGroundTruthSource: """ @@ -87,45 +82,3 @@ def _create_tokenizer(model_name): tokenizer.pad_token = tokenizer.eos_token return tokenizer - - -def parse_params(config_path: str) -> Tuple[Optional[Dict], Optional[str]]: - # parses the config file provided - assert os.path.isfile(config_path), f"config_path {config_path} is not a file" - # reads the yaml file - with open(config_path, "r") as f: - config = yaml.safe_load(f) - - cadence = os.environ.get("CADENCE", "commit") - if cadence == config["cadence"]: - return config, None - return None, "Skipping test for cadence: {}".format(config["cadence"]) - - -def validate_cache_management_type( - internal_kv_cache, cache_management_type: Union[str, List[str]] -) -> bool: - if internal_kv_cache and "internal" not in cache_management_type: - pytest.skip( - "The tests for running the pipeline with " - "internal kv cache management are disabled." - ) - if not internal_kv_cache and "external" not in cache_management_type: - pytest.skip( - "The tests for running the pipeline with " - "external kv cache management are disabled." - ) - return internal_kv_cache - - -def helper_test(test_method): - @functools.wraps(test_method) - def wrapper(self, setup): - if not self.run_helper_tests: - raise pytest.skip( - "Skipping the helper test. Set run_helper_tests to True to run it." - ) - - return test_method(self, setup) - - return wrapper diff --git a/tests/deepsparse/transformers/pipelines/test_chat.py b/tests/deepsparse/transformers/pipelines/test_chat.py index 4da1c66726..2a6b5d1ebf 100644 --- a/tests/deepsparse/transformers/pipelines/test_chat.py +++ b/tests/deepsparse/transformers/pipelines/test_chat.py @@ -12,157 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
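For context on the text_generation.py change in PATCH 2/5 above: because engine_forward yields elsewhere (it backs the streaming mode), Python treats the whole function as a generator, so the old return of numpy.array(...), prompt_logits on the early-exit path silently discarded its value; the fix yields the result and then bare-returns. A minimal, framework-independent illustration of that pitfall:

def broken():
    done = True
    if done:
        return [1, 2, 3]  # inside a generator, this value is discarded
    yield from range(5)


def fixed():
    done = True
    if done:
        yield [1, 2, 3]  # the value actually reaches the caller
        return  # a bare return just ends the generator
    yield from range(5)


assert list(broken()) == []
assert list(fixed()) == [[1, 2, 3]]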
-import numpy as np -from transformers import GenerationConfig - import pytest -from tests.deepsparse.transformers.pipelines.helpers import helper_test -from tests.deepsparse.transformers.pipelines.test_text_generation import ( - TestTextGenerationPipeline, +from deepsparse import Pipeline + + +@pytest.mark.parametrize( + "pipeline_kwargs", + [ + dict( + model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/" + "huggingface/bigpython_bigquery_thepile/base-none", + engine_type="onnxruntime", + ), + ], ) +@pytest.mark.skip(reason="too heavy for now to run in gha") +def test_chat_pipeline_session_manager(pipeline_kwargs): + chat_pipeline = Pipeline.create(task="chat", **pipeline_kwargs) - -@pytest.fixture -def config(request): - return request.param - - -class TestChatPipeline(TestTextGenerationPipeline): - @pytest.fixture - def pipeline_type(self): - return "chat" - - @helper_test - def test_chat_pipeline_session_manager(self, setup): - pipeline = self.get_pipeline() - - with pipeline.session(): - output_1 = pipeline( - prompt="first", generation_config=dict(max_new_tokens=1) - ) - output_2 = pipeline( - prompt="second", generation_config=dict(max_new_tokens=1) - ) - # assert inferences in the same context share a session id - assert output_1.session_ids == output_2.session_ids - - # test that follow-up inference has a different session id - output_3 = pipeline(prompt="third", generation_config=dict(max_new_tokens=1)) - assert output_3.session_ids != output_1.session_ids - - @helper_test - def test_run_with_same_session_ids(self, setup): - # Test the scenario where the same session ids are used for multiple - # inference runs. There are two conditions that must be fulfilled: - # 1. The information regarding the prompt does not leak between sessions - # 2. Running two prompts one after another is identical to running - # a composition of those prompts i.e. - # generated_text = pipeline(prompt_1) - # generated_text_2 = pipeline(prompt_2) - # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) - - prompt_1 = "This prompt is used for testing purposes. 
To this to make sure that" - prompt_2 = "still this prompt should not" - num_generated_tokens = 32 - - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=False, - ) - self._test_run_with_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill=True, - ) - - def _test_run_with_same_session_ids( - self, - prompt_1, - prompt_2, - num_generated_tokens, - multi_token_prefill, - ): - pipeline = self.get_pipeline( - prompt_sequence_length=self.prompt_sequence_length - if multi_token_prefill - else 1, - ) - - # make sure information does not leak between sessions - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_1", - session_id_2="test_2", - ) - - self._test_composition_same_session_ids( - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1="test_3", - session_id_2="test_4", - ) - - def _test_composition_same_session_ids( - self, - prompt_1, - prompt_2, - num_generated_tokens, - pipeline, - session_id_1, - session_id_2, - ): - - tokenizer = pipeline.tokenizer - config = GenerationConfig( - output_scores=True, max_length=num_generated_tokens, top_k=0, top_p=0.0 - ) - - # make sure that running two prompts one after another - # is identical to running a composition of those prompts - out_1_ = pipeline( - sequences=prompt_1, - force_max_tokens=True, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, + with chat_pipeline.session(): + output_1 = chat_pipeline( + prompt="first", generation_config=dict(max_new_tokens=1) ) - prompt_1_ = out_1_.generations[0].text - out_1 = pipeline( - sequences=prompt_2, - force_max_tokens=True, - session_ids=session_id_1, - generation_config=config, - include_prompt_logits=True, + output_2 = chat_pipeline( + prompt="second", generation_config=dict(max_new_tokens=1) ) - cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ - "past_key_values.0.key" - ] + # assert inferences in the same context share a session id + assert output_1.session_ids == output_2.session_ids - prompt_composition = tokenizer.decode( - tokenizer(prompt_1).input_ids - + tokenizer(prompt_1_).input_ids - + tokenizer(prompt_2).input_ids, - skip_special_tokens=True, - ) - out_2 = pipeline( - sequences=prompt_composition, - session_ids=session_id_2, - generation_config=config, - include_prompt_logits=True, - ) - cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ - "past_key_values.0.key" - ] - if cache_state_1.shape[0]: - # if cache state is not empty, i.e. we are managing kv cache - # externally, make sure that the cache state is the same - np.allclose(cache_state_1, cache_state_2, atol=self.precision) - assert out_1.generations[0].text == out_2.generations[0].text + # test that follow-up inference has a different session id + output_3 = chat_pipeline(prompt="third", generation_config=dict(max_new_tokens=1)) + assert output_3.session_ids != output_1.session_ids diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index b8a3e4ff4b..5298c2f1dd 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -11,35 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -""" -A sample config file requires the following arguments: - model_path: The path to the model to be tested - (sparsezoo stub/hf model path/local_path) - model_name: The name of the hugging face model - (to generate ground truth info) - pipeline_type: The type of the pipeline to be tested - (e.g. text-generation) - num_tokens_generate: The number of tokens to generate - prompt: The prompt to use for testing - has_bos_token: Whether the model has a bos token - logits_threshold: The treshold for the max difference between the - actual and the expected logits in the situations - where they will not be able to match the ground - truth (e.g. when the DeepSparse pipeline is - running after the KV cache has been filled up). - This value is established - empirically for each combination of - prompt/pipeline/num_generated tokens. - precision: The precision for the logits/kv_cache entries comparison - cache_management_type: The type of cache management to be tested. - The available options are: "internal" and "external". - run_helper_tests: Whether to run the helper test for the pipeline. Helper tests - check functionalities of the pipeline that are not directly - on the hot path. - cadence: The cadence of the tests. The available options are: - "nightly" and "commit". By default, only the tests that have cadence - "commit" will be run in GHA. -""" + import inspect from typing import List, Optional, Tuple @@ -48,149 +20,150 @@ import pytest from deepsparse import Pipeline -from deepsparse.transformers.pipelines.text_generation import TextGenerationOutput from deepsparse.transformers.utils.helpers import prepends_bos_token -from tests.deepsparse.transformers.pipelines.helpers import ( - TorchGroundTruthSource, - helper_test, - parse_params, - validate_cache_management_type, -) +from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource -# the user can specify the config file to be used for the tests -# TODO: add more configs -# TODO: add explanation -AVAILABLE_CONFIGS = [ - "tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml", - # "tests/deepsparse/transformers/pipelines/configs/text_generation_opt.yaml", - # "tests/deepsparse/transformers/pipelines/configs/text_generation_codegen.yaml", -] +_PRECISION = 1e-3 +NATURAL_LANGUAGE_PROMPT = """ +Didn't know what time it was, the lights were low +I leaned back on my radio +Some cat was layin' down some rock 'n' roll +"Lotta soul," he said +Then the loud sound did seem to fade +Came back like a slow voice on a wave of phase +That weren't no DJ, that was hazy cosmic jive +""" -@pytest.fixture -def config(request): - return request.param +CODE_LANGUAGE_PROMPT = """ +def Fibonacci(n): + # Check if input is 0 then it will + # print incorrect input + if n < 0: + print("Incorrect input") + # Check if n is 0 + # then it will return 0 + elif n == 0: + return 0 +""" -@pytest.mark.parametrize("config", AVAILABLE_CONFIGS, indirect=["config"]) @pytest.mark.parametrize( "internal_kv_cache", - [True, False], + [ + True, + False, + ], +) +@pytest.mark.parametrize( + "pipeline_type", + ["text_generation", "chat"], +) +@pytest.mark.parametrize( + "model_stub, " + "model_name, " + "uses_bos_token, " + "prompt, " + "logits_max_diff_kv_cache_has_been_filled", + [ + ( + "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" + "huggingface/bigpython_bigquery_thepile/base-none", + "salesforce/codegen-350m-mono", + False, + CODE_LANGUAGE_PROMPT, 
+ 13, + ), + ( + "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" + "opt_pretrain/base-none", + "facebook/opt-1.3b", + True, + NATURAL_LANGUAGE_PROMPT, + 3.9, + ), + ], + scope="class", ) +@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") class TestTextGenerationPipeline: """ This test suite is meant to test the main scenarios of the text generation pipeline. """ - @pytest.fixture - def pipeline_type(self): - return "text-generation" - - def get_pipeline(self, **kwargs) -> Pipeline: - """ - If no kwargs provided, returns the cached "default" - pipeline that is used for most of the tests. - Otherwise, returns a pipeline with the given kwargs - (the default pipeline kwargs are updated with the - user-provided kwargs) - - :param kwargs: the optional kwargs to be used to - create the pipeline (if not provided, the cached - "default" pipeline is returned) - :return: the appropriate pipeline - """ + def get_pipeline(self, **kwargs): if not kwargs: - if self.default_pipeline is None: - self.default_pipeline = Pipeline.create(**self.default_pipeline_kwargs) - return self.default_pipeline - - # return a pipeline with the updated default kwargs - updated_kwargs = self.default_pipeline_kwargs.copy() - updated_kwargs.update(kwargs) - return Pipeline.create(**updated_kwargs) - - def run_pipeline(self, pipeline: Pipeline, **kwargs) -> TextGenerationOutput: - """ - Run the pipeline and return the output - - :param pipeline: the pipeline to be run - :param kwargs: the optional kwargs to be used to - run the pipeline - :return: the pipeline output - """ - sequences = kwargs.get("sequences", self.prompt) - num_return_sequences = kwargs.get("num_return_sequences", 1) - do_sample = kwargs.get("do_sample", False) - streaming = kwargs.get("streaming", False) - - config = GenerationConfig( - output_scores=True, - max_length=self.num_tokens_generate, - top_k=0, - top_p=0.0, - num_return_sequences=num_return_sequences, - do_sample=do_sample, - ) - return pipeline( - sequences=sequences, - force_max_tokens=True, - include_prompt_logits=True, - generation_config=config, - streaming=streaming, - ) + # return the default pipeline + if self.default_pipeline: + return self.default_pipeline + else: + self.default_pipeline = Pipeline.create( + task=self.pipeline_type, + model_path=self.model_stub, + internal_kv_cache=self.internal_kv_cache, + prompt_sequence_length=self.prompt_sequence_length, + sequence_length=self.sequence_length, + ) + return self.default_pipeline + # return a pipeline with the given kwargs + return Pipeline.create(**kwargs) @pytest.fixture - def setup(self, config, internal_kv_cache, pipeline_type): - params_dict, skip_reason = parse_params(config) - if params_dict is None: - # skip the test if the config file is not available - pytest.skip(skip_reason) - # set the params_dict as the class attributes - for key, value in params_dict.items(): - setattr(self, key, value) - # check whether the internal kv cache is supported for testing - # (skip if not supported) - self.internal_kv_cache: bool = validate_cache_management_type( - internal_kv_cache, self.cache_management_type - ) - # check whether the pipeline_type is supported for testing - # (skip if not supported) - if pipeline_type not in self.pipeline_type: - pytest.skip( - f"Pipeline type: {self.pipeline_type} " - f"does not match the current type: {pipeline_type}" - ) - + def setup( + self, + model_stub, + model_name, + uses_bos_token, + prompt, + logits_max_diff_kv_cache_has_been_filled, + 
internal_kv_cache, + pipeline_type, + ): + self.num_tokens_generate = 216 + self.model_stub = model_stub + self.prompt = prompt + self.pipeline_type = pipeline_type # create torch ground source torch_source = TorchGroundTruthSource( - num_tokens_to_generate=self.num_tokens_generate, model_name=self.model_name + num_tokens_to_generate=self.num_tokens_generate, model_name=model_name ) - # create torch ground truth - self.torch_ground_truth = torch_source(self.prompt) - prompt_length = self.torch_ground_truth[1].shape[1] + torch_ground_truth = torch_source(self.prompt) + + # prompt length is expressed in number of prompt tokens + prompt_length = torch_ground_truth[1].shape[1] # sequence_length that assures that the KV cache will not be filled up self.sequence_length = 2 * prompt_length + self.num_tokens_generate # sequence_length that assures that the KV cache will be filled up self.sequence_length_short = self.num_tokens_generate - # prompt_sequence_length used for the multi-token prefill scenario - self.prompt_sequence_length = prompt_length // 4 + # prompt_sequence_length used for the multitoken prefill scenario + self.prompt_sequence_length = prompt_length // 2 + + # the maximum threshold for the difference between the logits + # when running a scenario where KV Cache buffer has been filled + self.logits_max_diff_kv_cache_has_been_filled = ( + logits_max_diff_kv_cache_has_been_filled + ) + self.internal_kv_cache = internal_kv_cache + + self.default_pipeline = None + assert self.prompt_sequence_length < prompt_length, ( "The prompt processing sequence length " "must be smaller than the prompt length" ) - self.default_pipeline_kwargs = dict( - task=pipeline_type, - model_path=self.model_path, - internal_kv_cache=self.internal_kv_cache, - prompt_sequence_length=self.prompt_sequence_length, - sequence_length=self.sequence_length, - ) - self.default_pipeline = None + yield model_name, uses_bos_token, torch_ground_truth + + def test_freeze_first_position(self, setup): + # Test whether we should be "freezing" the first token after + # the kv cache is full + _, uses_bos_token, _ = setup + pipeline = self.get_pipeline() + assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token def test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. The test covers the @@ -203,21 +176,27 @@ def test_ort_single_token_prefill(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - + _, _, torch_ground_truth = setup pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, + sequence_length=self.sequence_length, prompt_sequence_length=1, engine_type="onnxruntime", ) pipeline._debug = True - output = self.run_pipeline(pipeline) - assert output.total_num_processed_tokens[0] < self.sequence_length, ( - "The total number of processed tokens must be smaller than the " - "sequence length" + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + + output = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - torch_ground_truth=self.torch_ground_truth, + torch_ground_truth=torch_ground_truth, ) def test_ort_multi_token_prefill(self, setup): @@ -231,16 +210,26 @@ def test_ort_multi_token_prefill(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." 
) + _, _, torch_ground_truth = setup pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=self.prompt_sequence_length, engine_type="onnxruntime", ) pipeline._debug = True - output = self.run_pipeline(pipeline) + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + output = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config + ) assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - torch_ground_truth=self.torch_ground_truth, + torch_ground_truth=torch_ground_truth, ) def test_ort_generation_after_kv_cache_has_been_filled(self, setup): @@ -254,13 +243,22 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): pytest.skip( "Cannot run ORT pipeline with the internal deepsparse cache enabled." ) - + _, _, torch_ground_truth = setup pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, sequence_length=self.sequence_length_short, + prompt_sequence_length=self.prompt_sequence_length, engine_type="onnxruntime", ) pipeline._debug = True - output = self.run_pipeline(pipeline) + + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + output = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config + ) assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " @@ -270,8 +268,8 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - torch_ground_truth=self.torch_ground_truth, - logits_threshold=self.logits_threshold, + torch_ground_truth=torch_ground_truth, + max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) def test_deepsparse_single_token_prefill(self, setup): @@ -281,21 +279,27 @@ def test_deepsparse_single_token_prefill(self, setup): # 2. The KV Cache is never filled up # 3. KV Cache managed externally or internally + _, _, torch_ground_truth = setup pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, + sequence_length=self.sequence_length, prompt_sequence_length=1, + internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True - output = self.run_pipeline(pipeline) - - assert output.total_num_processed_tokens[0] < self.sequence_length, ( - "The total number of processed tokens must be smaller than the " - "sequence length" + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + output = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config ) + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - torch_ground_truth=self.torch_ground_truth, - # disable kv cache validation if using internal kv cache - run_kv_cache_validation=not self.internal_kv_cache, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_multi_token_prefill(self, setup): @@ -303,33 +307,54 @@ def test_deepsparse_multi_token_prefill(self, setup): # following scenario: # 1. Prompt preprocessing is performed by multi-token engine # 2. The KV Cache is never filled up - # 3. KV Cache managed internally or externally + # 3. 
KV Cache managed externally or internally - pipeline = self.get_pipeline() + _, _, torch_ground_truth = setup + pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=self.prompt_sequence_length, + internal_kv_cache=self.internal_kv_cache, + ) pipeline._debug = True - output = self.run_pipeline(pipeline) - assert output.total_num_processed_tokens[0] < self.sequence_length + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + output = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config + ) + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - torch_ground_truth=self.torch_ground_truth, - # disable kv cache validation if using internal kv cache - run_kv_cache_validation=not self.internal_kv_cache, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): - # Test the deepsparse that uses deepsparse engine. The test covers the + # Test the pipeline that uses deepsparse engine. The test covers the # following scenario: # 1. Prompt preprocessing is performed by multi-token engine # 2. The KV Cache is filled up (old entries are removed) - # 3. KV Cache managed internally or externally + # 3. KV Cache managed externally or internally + _, _, torch_ground_truth = setup pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, sequence_length=self.sequence_length_short, + prompt_sequence_length=self.prompt_sequence_length, + internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True - output = self.run_pipeline(pipeline) + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + output = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config + ) assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " @@ -339,39 +364,47 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - torch_ground_truth=self.torch_ground_truth, - logits_threshold=self.logits_threshold, - run_kv_cache_validation=not self.internal_kv_cache, + torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, + max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) - @helper_test - def test_freeze_first_position(self, setup): - # Test whether we should be "freezing" the first token after - # the kv cache is full - pipeline = self.get_pipeline() - assert prepends_bos_token(pipeline.tokenizer) == self.has_bos_token - - @helper_test def test_run_same_prompt_multiple_times(self, setup): # Test the scenario, where the same prompt is run multiple times # Every run should produce the same output pipeline = self.get_pipeline() - output_1 = self.run_pipeline(pipeline) - output_2 = self.run_pipeline(pipeline) + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + + output_1 = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config + ) + + output_2 = pipeline( + sequences=self.prompt, include_prompt_logits=True, generation_config=config + ) assert output_1.generations[0].text == 
output_2.generations[0].text assert numpy.allclose( output_1.generations[0].score, output_2.generations[0].score, - atol=self.precision, + atol=_PRECISION, ) - @helper_test def test_run_multiple_prompts_in_parallel(self, setup): # Test the scenario, where multiple prompts are run in parallel # Same two prompts should produce the same output pipeline = self.get_pipeline() - output = self.run_pipeline(pipeline, sequences=[self.prompt, self.prompt]) + + config = GenerationConfig( + output_scores=True, max_length=self.num_tokens_generate, top_k=0, top_p=0.0 + ) + output = pipeline( + sequences=[self.prompt, self.prompt], + generation_config=config, + include_prompt_logits=True, + ) logits_0 = output.generations[0].score sequence_0 = output.generations[0].text @@ -379,68 +412,200 @@ def test_run_multiple_prompts_in_parallel(self, setup): logits_1 = output.generations[1].score sequence_1 = output.generations[1].text - assert numpy.allclose(logits_0, logits_1, atol=self.precision) + assert numpy.allclose(logits_0, logits_1, atol=_PRECISION) assert sequence_0 == sequence_1 - @helper_test def test_num_generated_predictions(self, setup): # Test the scenario, where multiple predictions are generated # from the same prompt pipeline = self.get_pipeline() - output_sequences = self.run_pipeline( - pipeline, sequences=[self.prompt], num_return_sequences=2 + config = GenerationConfig( + num_return_sequences=2, + max_length=self.num_tokens_generate, + top_k=0, + top_p=0.0, ) + + output_sequences = pipeline(sequences=[self.prompt], generation_config=config) assert len(output_sequences.generations) == 1 assert len(output_sequences.generations[0]) == 2 - output_sequences = self.run_pipeline( - pipeline, sequences=[self.prompt, self.prompt], num_return_sequences=2 + output_sequences = pipeline( + sequences=[self.prompt, self.prompt], generation_config=config ) assert len(output_sequences.generations) == 2 for generation in output_sequences.generations: assert len(generation) == 2 - @helper_test def test_token_generation_deterministic(self, setup): - pipeline = self.get_pipeline() - inference = self.run_pipeline(pipeline, num_return_sequences=3, do_sample=False) + pipeline_kwargs = { + "task": "text_generation", + "model_path": self.model_stub, + } + config = GenerationConfig( + output_scores=True, + max_length=self.num_tokens_generate, + top_k=0, + top_p=0.0, + num_return_sequences=3, + do_sample=False, + ) + pipeline = self.get_pipeline(**pipeline_kwargs) + inference = pipeline(sequences=["hello?"], generation_config=config) generations = inference.generations - # Output should be the same from one another text_outputs = [x.text for x in generations[0]] assert len(set(text_outputs)) == 1 - @helper_test def test_token_generation_non_deterministic(self, setup): - pipeline = self.get_pipeline() - inference = self.run_pipeline(pipeline, num_return_sequences=3, do_sample=True) + pipeline_kwargs = { + "task": "text_generation", + "model_path": self.model_stub, + } + pipeline = self.get_pipeline(**pipeline_kwargs) + config = GenerationConfig( + output_scores=True, + max_length=self.num_tokens_generate, + top_k=0, + top_p=0.0, + num_return_sequences=3, + do_sample=True, + ) + inference = pipeline(sequences=["hello?"], generation_config=config) generations = inference.generations - # Output should be different from one another + # Output should be the same from one another text_outputs = [x.text for x in generations[0]] assert len(set(text_outputs)) == 3 - @helper_test - def test_streaming_mode_returns_generator(self, 
setup): - pipeline = self.get_pipeline(prompt_sequence_length=1) - response_generator = self.run_pipeline(pipeline, streaming=True) + def test_run_with_same_session_ids(self, setup): + # Test the scenario where the same session ids are used for multiple + # inference runs. There are two conditions that must be fulfilled: + # 1. The information regarding the prompt does not leak between sessions + # 2. Running two prompts one after another is identical to running + # a composition of those prompts i.e. + # generated_text = pipeline(prompt_1) + # generated_text_2 = pipeline(prompt_2) + # generated_text_2 == pipeline(prompt_1 + generated_text + prompt_2) + + if self.pipeline_type not in ["chatbot", "chat"]: + pytest.skip("This test is only applicable to chatbot pipeline") + + prompt_1 = "This prompt is used for testing purposes. To this to make sure that" + prompt_2 = "still this prompt should not" + num_generated_tokens = 32 + + self._test_run_with_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + multi_token_prefill=False, + ) + self._test_run_with_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + multi_token_prefill=True, + ) - assert inspect.isgenerator( - response_generator - ), "Pipeline should return a generator in streaming mode" + def _test_run_with_same_session_ids( + self, + prompt_1, + prompt_2, + num_generated_tokens, + multi_token_prefill, + ): + pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, + prompt_sequence_length=self.prompt_sequence_length + if multi_token_prefill + else 1, + force_max_tokens=True, + internal_kv_cache=self.internal_kv_cache, + ) - assert all( - isinstance(response, pipeline.output_schema) - for response in response_generator - ), "Pipeline should return a generator of output_schema \ - objects in streaming mode" + # make sure information does not leak between sessions + + self._test_composition_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1="test_1", + session_id_2="test_2", + ) + + self._test_composition_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1="test_3", + session_id_2="test_4", + ) + + @staticmethod + def _test_composition_same_session_ids( + prompt_1, + prompt_2, + num_generated_tokens, + pipeline, + session_id_1, + session_id_2, + ): + + tokenizer = pipeline.tokenizer + config = GenerationConfig( + output_scores=True, max_length=num_generated_tokens, top_k=0, top_p=0.0 + ) + + # make sure that running two prompts one after another + # is identical to running a composition of those prompts + out_1_ = pipeline( + sequences=prompt_1, + session_ids=session_id_1, + generation_config=config, + include_prompt_logits=True, + ) + prompt_1_ = out_1_.generations[0].text + out_1 = pipeline( + sequences=prompt_2, + session_ids=session_id_1, + generation_config=config, + include_prompt_logits=True, + ) + cache_state_1 = pipeline.storage_kv_cache.get(session_id_1).cached_inputs[ + "past_key_values.0.key" + ] + + prompt_composition = tokenizer.decode( + tokenizer(prompt_1).input_ids + + tokenizer(prompt_1_).input_ids + + tokenizer(prompt_2).input_ids, + skip_special_tokens=True, + ) + out_2 = pipeline( + sequences=prompt_composition, + session_ids=session_id_2, + generation_config=config, + include_prompt_logits=True, + ) + cache_state_2 = pipeline.storage_kv_cache.get(session_id_2).cached_inputs[ + "past_key_values.0.key" + ] + if cache_state_1.shape[0]: + # if cache state is not empty, i.e. 
we are managing kv cache + # externally, make sure that the cache state is the same + numpy.allclose(cache_state_1, cache_state_2, atol=_PRECISION) + assert out_1.generations[0].text == out_2.generations[0].text def _test_output( self, - output: TextGenerationOutput, + output: "TextGenerationOutput", # noqa F821 torch_ground_truth: Tuple[numpy.ndarray, ...], - logits_threshold: Optional[float] = None, - run_kv_cache_validation: bool = True, + max_logits_difference_threshold: Optional[float] = None, + run_cache_validation: bool = True, ): ( @@ -450,42 +615,40 @@ def _test_output( generated_text, ) = torch_ground_truth - # concatenate target prompt_logits and generated_logits + # concatenate target prompt_logits and generated_logits and check target_logits = numpy.concatenate([prompt_logits, generated_logits], axis=1) - # get the logits of the generated sequence score = output.generations[0].score - if logits_threshold: + if max_logits_difference_threshold: # if comparing the output from the model where # the kv cache has been filled, we expect the # maximum absolute difference between the logits # to be less than the threshold # (the threshold is established by running the # ONNX model in ONNXRuntime) - target_logits = target_logits[0] - if target_logits.shape[0] < score.shape[0]: - score = score[: target_logits.shape[0], :] - assert abs(score - target_logits).max() < logits_threshold + assert abs(score - target_logits[0]).max() < max_logits_difference_threshold else: # otherwise, we expect the logits to be exactly the same # as the target logits; the generated sequence should - # also be the same as the target sequence - assert numpy.allclose(score, target_logits[0], atol=self.precision) + # also be the same as the target sequence, and finally + # (if applicable) the kv cache should be the same as the + # target kv cache + + assert numpy.allclose(score, target_logits[0], atol=_PRECISION) assert self.prompt + output.generations[0].text == generated_text - if hasattr(output, "kv_cache_state") and run_kv_cache_validation: - # (if applicable) the kv cache should be the same as the - # target kv cache - expected_cache = list(output.kv_cache_state[0].values()) + if run_cache_validation: + # extract numpy arrays from cached_inputs + kv_cache_array = list(output.kv_cache_state[0].values()) total_num_processed_tokens = output.total_num_processed_tokens[0] self._test_kv_cache_state( - expected_cache=expected_cache, - target_cache=prompt_kv_cache, + expected_cache=kv_cache_array, + target_cache=torch_ground_truth[2], total_num_processed_tokens=total_num_processed_tokens, ) + @staticmethod def _test_kv_cache_state( - self, expected_cache: List[numpy.ndarray], target_cache: List[numpy.ndarray], total_num_processed_tokens: int, @@ -500,5 +663,25 @@ def _test_kv_cache_state( # as target_cache only pertains to prompt cache entries, we need to # compare only the prompt cache entries in x with y assert numpy.allclose( - x[:, :, -start_index:-end_index, :], y, atol=self.precision + x[:, :, -start_index:-end_index, :], y, atol=_PRECISION ) + + def test_streaming_mode_returns_generator(self, setup): + pipeline = self.get_pipeline( + task=self.pipeline_type, + model_path=self.model_stub, + sequence_length=self.sequence_length, + prompt_sequence_length=1, + ) + inputs = dict(prompt=self.prompt, streaming=True) + response_generator = pipeline(**inputs) + + assert inspect.isgenerator( + response_generator + ), "Pipeline should return a generator in streaming mode" + + assert all( + isinstance(response, 
pipeline.output_schema) + for response in response_generator + ), "Pipeline should return a generator of output_schema \ + objects in streaming mode" From 4aa3c663ae9593ffb7494b731e70cdf5986d80b9 Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 18 Oct 2023 06:39:55 +0000 Subject: [PATCH 4/5] revert changes2 --- .../transformers/pipelines/configs/__init__.py | 13 ------------- .../transformers/pipelines/configs/gpt_neo.yaml | 13 ------------- 2 files changed, 26 deletions(-) delete mode 100644 tests/deepsparse/transformers/pipelines/configs/__init__.py delete mode 100644 tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml diff --git a/tests/deepsparse/transformers/pipelines/configs/__init__.py b/tests/deepsparse/transformers/pipelines/configs/__init__.py deleted file mode 100644 index 0c44f887a4..0000000000 --- a/tests/deepsparse/transformers/pipelines/configs/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml b/tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml deleted file mode 100644 index df209b5bee..0000000000 --- a/tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml +++ /dev/null @@ -1,13 +0,0 @@ -model_path: "hf:mgoin/TinyStories-1M-deepsparse" -model_name: "roneneldan/TinyStories-1M" -pipeline_type: ["text-generation", "chat"] -num_tokens_generate: 128 -prompt: "Didn't know what time it was, the lights were low\n I leaned back on my radio\n Some cat was layin' down some rock 'n' roll\n \"Lotta soul,\" he said\n Then the loud sound did seem to fade\n Came back like a slow voice on a wave of phase\n That weren't no DJ, that was hazy cosmic jive" -has_bos_token: False -logits_threshold: 24.7 -precision: 0.001 -cache_management_type: - - "internal" - - "external" -run_helper_tests: True -cadence: "commit" \ No newline at end of file From 3bf37cdd56d69b267b6cd24eff21f77be4a57726 Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 18 Oct 2023 13:40:58 +0000 Subject: [PATCH 5/5] rahuls suggestion --- .../transformers/engines/nl_decoder_engine.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index dabf520ff3..2670a315e9 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -154,30 +154,29 @@ def run( :return: The output of the engine """ - if kv_cache is not None: - # run the engine assuming kv cache support - if bool(kv_cache.engine_internal_cache): - # run the engine assuming internal kv cache - # management. In this case the LIB.kv_cache - # class object will be passed to the engine - # call as well - # conventionally, before dispatching - # inputs to the engine, we validate them - # if val_inp=True. 
However, in this case - # we want to pass the empty kv cache inputs - # (batch_size=0) to the engine. Therefore, - # we skip the validation - return self.engine._eng_net.execute_list_out( - inputs, kv_cache.engine_internal_cache - ) - else: - # run the engine assuming external kv cache - # management. - return self.engine.run(inputs, val_inp, kv_cache) - else: + if kv_cache is None: # run the engine without the kv cache support return self.engine.run(inputs, val_inp) + if bool(kv_cache.engine_internal_cache): + # run the engine assuming internal kv cache + # management. In this case the LIB.kv_cache + # class object will be passed to the engine + # call as well + # conventionally, before dispatching + # inputs to the engine, we validate them + # if val_inp=True. However, in this case + # we want to pass the empty kv cache inputs + # (batch_size=0) to the engine. Therefore, + # we skip the validation + return self.engine._eng_net.execute_list_out( + inputs, kv_cache.engine_internal_cache + ) + else: + # run the engine assuming external kv cache + # management. + return self.engine.run(inputs, val_inp, kv_cache) + def __call__( self, inp: List[numpy.ndarray],
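
For reference, below is a minimal, self-contained sketch of the dispatch order that the refactored run() in PATCH 5/5 settles on: return early when no kv cache is supplied, then branch on internal versus external cache management. The FakeEngine and FakeKVCache classes are hypothetical stand-ins used only to make the sketch runnable on its own; they are not the DeepSparse engine API, and only the branching order mirrors the diff above.

    # sketch only: stand-in types, not the DeepSparse NLDecoderEngine
    from dataclasses import dataclass
    from typing import Any, List, Optional


    @dataclass
    class FakeKVCache:
        # stand-in for the decoder kv cache; `engine_internal_cache` is truthy
        # when the engine manages the cache internally, falsy otherwise
        engine_internal_cache: Optional[object] = None


    class FakeEngine:
        # stand-in for the wrapped engine; each method just reports the path taken
        def run(
            self,
            inputs: List[Any],
            val_inp: bool,
            kv_cache: Optional[FakeKVCache] = None,
        ) -> str:
            return "external-cache run" if kv_cache is not None else "no-cache run"

        def execute_internal(self, inputs: List[Any], internal_cache: object) -> str:
            # in the real engine this is the validation-skipping call that accepts
            # empty (batch_size=0) cache inputs; here it is only a placeholder
            return "internal-cache run"


    def run(
        engine: FakeEngine,
        inputs: List[Any],
        val_inp: bool,
        kv_cache: Optional[FakeKVCache],
    ) -> str:
        if kv_cache is None:
            # no kv cache support: plain engine call, early return
            return engine.run(inputs, val_inp)
        if kv_cache.engine_internal_cache:
            # internal kv cache management: hand the cache object to the engine
            return engine.execute_internal(inputs, kv_cache.engine_internal_cache)
        # external kv cache management
        return engine.run(inputs, val_inp, kv_cache)


    if __name__ == "__main__":
        engine = FakeEngine()
        print(run(engine, inputs=[], val_inp=True, kv_cache=None))
        print(run(engine, inputs=[], val_inp=True, kv_cache=FakeKVCache(object())))
        print(run(engine, inputs=[], val_inp=True, kv_cache=FakeKVCache(None)))

Flattening the nested if/else this way keeps the three cache-management paths at the same indentation level, which is the readability gain the final patch is after.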