Commit

ready for review
dbogunowicz committed Oct 17, 2023
1 parent 380869a commit 2a40416
Showing 3 changed files with 35 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -636,6 +636,9 @@ def process_engine_outputs(
             created=datetime.datetime.now(), prompts=prompts, generations=generations
         )
 
+        if "session_ids" in kwargs:
+            outputs["session_ids"] = kwargs["session_ids"]
+
         if self._debug:
             debug_params = dict(
                 kv_cache_state=kv_cache_state,
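The net effect, as a usage sketch: session ids supplied as a keyword argument are now echoed back on the returned output, so callers can match generations to sessions. The pipeline construction, prompt, and dict-style output access below are illustrative assumptions, not shown in this diff:

    # hypothetical caller: session ids passed in as a keyword argument
    # are returned on the output unchanged
    output = pipeline(sequences="def fib(n):", session_ids=["session-0"])
    assert output["session_ids"] == ["session-0"]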
10 changes: 9 additions & 1 deletion tests/deepsparse/transformers/pipelines/helpers.py
@@ -97,7 +97,11 @@ def parse_params(config_path: str) -> Tuple[Optional[Dict], Optional[str]]:
         config = yaml.safe_load(f)
 
     cadence = os.environ.get("CADENCE", "commit")
-    if cadence == config["cadence"]:
+    expected_cadence = config["cadence"]
+
+    if not isinstance(expected_cadence, list):
+        expected_cadence = [expected_cadence]
+    if cadence in expected_cadence:
         return config, None
     return None, "Skipping test for cadence: {}".format(config["cadence"])
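With this change a config's cadence field may be a single string or a list of strings. A minimal sketch of the matching behavior, with hypothetical values:

    # e.g. CADENCE=nightly in the environment and `cadence: [commit, nightly]`
    # in the config:
    cadence = "nightly"
    expected_cadence = ["commit", "nightly"]
    if not isinstance(expected_cadence, list):  # already a list, left as-is
        expected_cadence = [expected_cadence]
    assert cadence in expected_cadence          # test is not skipped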

@@ -129,3 +133,7 @@ def wrapper(self, setup):
             return test_method(self, setup)
 
     return wrapper
+
+
+def find_closest_number_divisible_by_four(number):
+    return number - (number % 4)
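Note that, despite its name, this helper rounds down to the nearest multiple of four rather than to the closest one; a couple of illustrative calls:

    find_closest_number_divisible_by_four(10)  # returns 8, not 12
    find_closest_number_divisible_by_four(12)  # returns 12 (already divisible)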
32 changes: 23 additions & 9 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-A sample config file requires the following arguments:
+This test suite consumes config files to test the text generation pipeline
+for various scenarios.
+
+A sample config file is a yaml that requires the following fields:
     model_path: The path to the model to be tested
                 (sparsezoo stub/hf model path/local_path)
     model_name: The name of the hugging face model
@@ -35,10 +35,11 @@
                            The available options are: "internal" and "external".
     run_helper_tests: Whether to run the helper test for the pipeline. Helper tests
                       check functionalities of the pipeline that are not directly
-                      on the hot path.
+                      on the hot path. They are decorated with @helper_test.
     cadence: The cadence of the tests. The available options are:
              "nightly" and "commit". By default, only the tests that have cadence
-             "commit" will be run in GHA.
+             "commit" will be run in GHA. This parameter can be either a string
+             or a list of strings.
 """
 import inspect
 from typing import List, Optional, Tuple
@@ -52,15 +52,15 @@
 from deepsparse.transformers.utils.helpers import prepends_bos_token
 from tests.deepsparse.transformers.pipelines.helpers import (
     TorchGroundTruthSource,
+    find_closest_number_divisible_by_four,
     helper_test,
     parse_params,
     validate_cache_management_type,
 )
 
 
 # the user can specify the config file to be used for the tests
-# TODO: add more configs
-# TODO: add explanation
+# TODO: add more configs once the PR is reviewed
 AVAILABLE_CONFIGS = [
     "tests/deepsparse/transformers/pipelines/configs/gpt_neo.yaml",
     # "tests/deepsparse/transformers/pipelines/configs/text_generation_opt.yaml",
@@ -150,8 +154,8 @@ def setup(self, config, internal_kv_cache, pipeline_type):
         # set the params_dict as the class attributes
         for key, value in params_dict.items():
             setattr(self, key, value)
-        # check whether the internal kv cache is supported for testing
-        # (skip if not supported)
+        # check whether the specified cache management type
+        # is supported for testing (skip if not supported)
         self.internal_kv_cache: bool = validate_cache_management_type(
             internal_kv_cache, self.cache_management_type
         )
@@ -178,11 +182,17 @@ def setup(self, config, internal_kv_cache, pipeline_type):

         # prompt_sequence_length used for the multi-token prefill scenario
         self.prompt_sequence_length = prompt_length // 4
+        # TODO: Per @tlrmchlsmth, the prompt_sequence_length must be divisible
+        # by 4 (at least for now); to be changed soon
+        self.prompt_sequence_length = find_closest_number_divisible_by_four(
+            self.prompt_sequence_length
+        )
         assert self.prompt_sequence_length < prompt_length, (
             "The prompt processing sequence length "
             "must be smaller than the prompt length"
         )
 
+        # specify the default pipeline kwargs
         self.default_pipeline_kwargs = dict(
             task=pipeline_type,
             model_path=self.model_path,
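To make the arithmetic above concrete, with a hypothetical 30-token prompt:

    prompt_length = 30                                   # hypothetical
    prompt_sequence_length = prompt_length // 4          # 7
    prompt_sequence_length = find_closest_number_divisible_by_four(
        prompt_sequence_length
    )                                                    # 7 rounds down to 4
    assert prompt_sequence_length < prompt_length        # 4 < 30 holds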
@@ -237,7 +247,10 @@ def test_ort_multi_token_prefill(self, setup):
         pipeline._debug = True
         output = self.run_pipeline(pipeline)
 
-        assert output.total_num_processed_tokens[0] < self.sequence_length
+        assert output.total_num_processed_tokens[0] < self.sequence_length, (
+            "The total number of processed tokens must be smaller than the "
+            "sequence length"
+        )
         self._test_output(
             output=output,
             torch_ground_truth=self.torch_ground_truth,
@@ -341,6 +354,7 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup):
             output=output,
             torch_ground_truth=self.torch_ground_truth,
             logits_threshold=self.logits_threshold,
+            # disable kv cache validation if using internal kv cache
             run_kv_cache_validation=not self.internal_kv_cache,
         )
