Commit
Merge branch 'main' into feature/damian/final_llm_tests
dbogunowicz authored Oct 19, 2023
2 parents 5f8357f + 9d0d897 commit 8cca910
Showing 2 changed files with 25 additions and 3 deletions.
14 changes: 12 additions & 2 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -154,7 +154,15 @@ def run(
         :return: The output of the engine
         """
+        if kv_cache is None:
+            # run the engine without the kv cache support
+            return self.engine.run(inputs, val_inp)
+
         if bool(kv_cache.engine_internal_cache):
+            # run the engine assuming internal kv cache
+            # management. In this case the LIB.kv_cache
+            # class object will be passed to the engine
+            # call as well
             # conventionally, before dispatching
             # inputs to the engine, we validate them
             # if val_inp=True. However, in this case
@@ -164,8 +172,10 @@ def run(
             return self.engine._eng_net.execute_list_out(
                 inputs, kv_cache.engine_internal_cache
             )
-        # run the engine without the LIB.kv_cache object
-        return self.engine.run(inputs, val_inp)
+        else:
+            # run the engine assuming external kv cache
+            # management.
+            return self.engine.run(inputs, val_inp, kv_cache)

     def __call__(
         self,
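Taken together, the two hunks above give run() a three-way dispatch on the kv cache. A minimal standalone sketch of that control flow, where engine and kv_cache are hypothetical stand-ins for the repository's objects, not its actual classes:

    from typing import List

    import numpy


    def run(engine, inputs: List[numpy.ndarray], val_inp: bool, kv_cache) -> List[numpy.ndarray]:
        if kv_cache is None:
            # no kv cache support at all: plain engine call
            return engine.run(inputs, val_inp)
        if bool(kv_cache.engine_internal_cache):
            # internal kv cache management: pass the LIB.kv_cache object
            # through and skip input validation (the kv cache inputs are
            # empty, batch_size=0, so validation would reject them)
            return engine._eng_net.execute_list_out(
                inputs, kv_cache.engine_internal_cache
            )
        # external kv cache management: the caller owns the cache
        return engine.run(inputs, val_inp, kv_cache)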
14 changes: 13 additions & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
@@ -252,6 +252,15 @@ def __init__(
         if "WAND_OPT_FLAGS" not in os.environ:
             os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

+        # the current requirement on the deepsparse engine
+        # is that prompt_sequence_length
+        # must be 1 or a multiple of four.
+        # for simplicity let's extend this requirement to all engines
+        if (prompt_sequence_length % 4 != 0) and (prompt_sequence_length != 1):
+            raise ValueError(
+                f"prompt_sequence_length must be 1 or multiple of 4. "
+                f"prompt_sequence_length is {prompt_sequence_length}"
+            )
         self.prompt_sequence_length = prompt_sequence_length
         self.force_max_tokens = force_max_tokens
         self.internal_kv_cache = internal_kv_cache
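To make the new constraint concrete, here is a standalone version of the check with a few example values; the helper name validate_prompt_sequence_length is ours for illustration, not part of the pipeline's API:

    def validate_prompt_sequence_length(prompt_sequence_length: int) -> int:
        # mirrors the added check: only 1 or a multiple of 4 is accepted
        if (prompt_sequence_length % 4 != 0) and (prompt_sequence_length != 1):
            raise ValueError(
                f"prompt_sequence_length must be 1 or multiple of 4. "
                f"prompt_sequence_length is {prompt_sequence_length}"
            )
        return prompt_sequence_length


    accepted = [n for n in range(1, 17) if n == 1 or n % 4 == 0]
    assert accepted == [1, 4, 8, 12, 16]  # e.g. 2, 3, 5 now raise in __init__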
@@ -686,7 +695,10 @@ def engine_forward(
             )
             for prompt_logit in prompt_logits:
                 token_generator.generate(prompt_logit)
-            return numpy.array([self.tokens]), prompt_logits
+            yield numpy.array([token_generator.tokens]), prompt_logits, [
+                FinishReason.LENGTH
+            ]
+            return

         else:
             # run the prompt through
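The last hunk swaps a plain return for yield ... followed by a bare return, so engine_forward behaves as a generator even on this early-exit path, emitting a (tokens, logits, finish reasons) triple. A rough sketch of what a caller sees, assuming that shape from this hunk alone; the FinishReason enum below is a stand-in for the pipeline's:

    from enum import Enum
    from typing import Generator, List, Tuple

    import numpy


    class FinishReason(Enum):
        # stand-in for the pipeline's FinishReason
        LENGTH = "length"


    def engine_forward(
        tokens: List[int], prompt_logits: numpy.ndarray
    ) -> Generator[Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]], None, None]:
        # mirror the changed branch: yield one triple, then stop
        yield numpy.array([tokens]), prompt_logits, [FinishReason.LENGTH]
        return


    # callers now iterate instead of unpacking a single return value
    for generated, logits, reasons in engine_forward([1, 2, 3], numpy.zeros((1, 3, 8))):
        assert reasons == [FinishReason.LENGTH]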
