diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 9c93616053..2670a315e9 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -154,7 +154,15 @@ def run( :return: The output of the engine """ + if kv_cache is None: + # run the engine without the kv cache support + return self.engine.run(inputs, val_inp) + if bool(kv_cache.engine_internal_cache): + # run the engine assuming internal kv cache + # management. In this case the LIB.kv_cache + # class object will be passed to the engine + # call as well # conventionally, before dispatching # inputs to the engine, we validate them # if val_inp=True. However, in this case @@ -164,8 +172,10 @@ return self.engine._eng_net.execute_list_out( inputs, kv_cache.engine_internal_cache ) - # run the engine without the LIB.kv_cache object - return self.engine.run(inputs, val_inp) + else: + # run the engine assuming external kv cache + # management. + return self.engine.run(inputs, val_inp, kv_cache) def __call__( self, diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 6947817f0a..3ef82ca7f3 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -252,6 +252,15 @@ def __init__( if "WAND_OPT_FLAGS" not in os.environ: os.environ["WAND_OPT_FLAGS"] = "default,~pyramids" + # the current requirement on the deepsparse engine + # is that prompt_sequence_length + # must be 1 or a multiple of four. + # for simplicity let's extend this requirement to all engines + if (prompt_sequence_length % 4 != 0) and (prompt_sequence_length != 1): + raise ValueError( + f"prompt_sequence_length must be 1 or multiple of 4. " f"prompt_sequence_length is {prompt_sequence_length}" ) self.prompt_sequence_length = prompt_sequence_length self.force_max_tokens = force_max_tokens self.internal_kv_cache = internal_kv_cache @@ -686,7 +695,10 @@ def engine_forward( ) for prompt_logit in prompt_logits: token_generator.generate(prompt_logit) - return numpy.array([self.tokens]), prompt_logits + yield numpy.array([token_generator.tokens]), prompt_logits, [ FinishReason.LENGTH ] + return else: # run the prompt through