diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 9c93616053..2670a315e9 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -154,7 +154,15 @@ def run( :return: The output of the engine """ + if kv_cache is None: + # run the engine without the kv cache support + return self.engine.run(inputs, val_inp) + if bool(kv_cache.engine_internal_cache): + # run the engine assuming internal kv cache + # management. In this case the LIB.kv_cache + # class object will be passed to the engine + # call as well # conventionally, before dispatching # inputs to the engine, we validate them # if val_inp=True. However, in this case @@ -164,8 +172,10 @@ def run( return self.engine._eng_net.execute_list_out( inputs, kv_cache.engine_internal_cache ) - # run the engine without the LIB.kv_cache object - return self.engine.run(inputs, val_inp) + else: + # run the engine assuming external kv cache + # management. + return self.engine.run(inputs, val_inp, kv_cache) def __call__( self, diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 06c1d9750c..9577b482f3 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -683,7 +683,10 @@ def engine_forward( ) for prompt_logit in prompt_logits: token_generator.generate(prompt_logit) - return numpy.array([self.tokens]), prompt_logits + yield numpy.array([token_generator.tokens]), prompt_logits, [ + FinishReason.LENGTH + ] + return else: # run the prompt through