From 9d0d89797a375057042caf39f607b6d1de0762da Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 19 Oct 2023 16:11:10 +0200 Subject: [PATCH] [Text Generation] Fail if `prompt_sequence_length % 4 != 0` (#1332) * cleanup * thanks michael --- src/deepsparse/transformers/pipelines/text_generation.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 9577b482f3..0fdf66fb5f 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -252,6 +252,15 @@ def __init__( if "WAND_OPT_FLAGS" not in os.environ: os.environ["WAND_OPT_FLAGS"] = "default,~pyramids" + # the current requirement on the deepsparse engine + # is that prompt_sequence_length + # must be 1 or a multiple of four. + # for simplicity let's extend this requirement to all engines + if (prompt_sequence_length % 4 != 0) and (prompt_sequence_length != 1): + raise ValueError( + f"prompt_sequence_length must be 1 or multiple of 4. " + f"prompt_sequence_length is {prompt_sequence_length}" + ) self.prompt_sequence_length = prompt_sequence_length self.force_max_tokens = force_max_tokens self.internal_kv_cache = internal_kv_cache