Skip to content

Commit

Permalink
[Text Generation] Fail if prompt_sequence_length % 4 != 0 (#1332)
Browse files Browse the repository at this point in the history
* cleanup

* thanks michael
  • Loading branch information
dbogunowicz authored Oct 19, 2023
1 parent 3c80189 commit 9d0d897
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,15 @@ def __init__(
if "WAND_OPT_FLAGS" not in os.environ:
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

# The deepsparse engine currently requires prompt_sequence_length
# to be either 1 or a multiple of 4.
# For simplicity, the same requirement is enforced for all engines.
if (prompt_sequence_length % 4 != 0) and (prompt_sequence_length != 1):
raise ValueError(
f"prompt_sequence_length must be 1 or multiple of 4. "
f"prompt_sequence_length is {prompt_sequence_length}"
)
self.prompt_sequence_length = prompt_sequence_length
self.force_max_tokens = force_max_tokens
self.internal_kv_cache = internal_kv_cache
Expand Down

0 comments on commit 9d0d897

Please sign in to comment.