From 39e21d3b86ded2c04929ada042adb667e2537032 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Thu, 30 Nov 2023 15:00:44 +0100
Subject: [PATCH] [Cherry-Pick] Fix the token_generator behavior for
 non-kv-cache models (#1441)

---
 .../transformers/utils/token_generator.py         | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/deepsparse/transformers/utils/token_generator.py b/src/deepsparse/transformers/utils/token_generator.py
index c2eace3654..65e1f8d16a 100644
--- a/src/deepsparse/transformers/utils/token_generator.py
+++ b/src/deepsparse/transformers/utils/token_generator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 import numpy
 
@@ -32,7 +32,7 @@ class TokenGenerator:
     def __init__(
         self,
         logits_shape: int,
-        tokens: List[int] = [],
+        tokens: Optional[List[int]] = None,
         deterministic: bool = True,
         sampling_temperature: float = 1.0,
         top_k: int = 0,
@@ -64,7 +64,7 @@ def __init__(
         self.top_p = top_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
-        self.tokens = tokens
+        self.tokens = [] if tokens is None else tokens
 
         self._initialize_token_frequencies()
 
@@ -77,11 +77,16 @@ def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
         :param logits: the logits from the model with shape (vocab_size,)
         :return: the sampled token
         """
+
         if self.deterministic:
             token = numpy.argmax(logits)
             self.tokens.append(token)
             return token
 
+        # copy the logits so the sampling steps below (e.g. top-k)
+        # never modify the caller's array in-place
+        logits = logits.copy()
+
         if self.top_k:
             logits = self.apply_top_k(logits)
 
@@ -173,5 +178,5 @@ def _update_frequencies(self, token: numpy.ndarray):
 
     def _initialize_token_frequencies(self):
         unique_tokens, frequencies = numpy.unique(self.tokens, return_counts=True)
-        for token, frequnecies in zip(unique_tokens, frequencies):
-            self.token_frequencies[token] += frequnecies
+        for token, freq in zip(unique_tokens, frequencies):
+            self.token_frequencies[token] += freq