From 39e21d3b86ded2c04929ada042adb667e2537032 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Thu, 30 Nov 2023 15:00:44 +0100 Subject: [PATCH] [Cherry-Pick] Fix the token_generator behavior for non-kv-cache models (#1441) --- .../transformers/utils/token_generator.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/deepsparse/transformers/utils/token_generator.py b/src/deepsparse/transformers/utils/token_generator.py index c2eace3654..65e1f8d16a 100644 --- a/src/deepsparse/transformers/utils/token_generator.py +++ b/src/deepsparse/transformers/utils/token_generator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional import numpy @@ -32,7 +32,7 @@ class TokenGenerator: def __init__( self, logits_shape: int, - tokens: List[int] = [], + tokens: Optional[List[int]] = None, deterministic: bool = True, sampling_temperature: float = 1.0, top_k: int = 0, @@ -64,7 +64,7 @@ def __init__( self.top_p = top_p self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty - self.tokens = tokens + self.tokens = [] if tokens is None else tokens self._initialize_token_frequencies() @@ -77,11 +77,16 @@ def generate(self, logits: numpy.ndarray) -> numpy.ndarray: :param logits: the logits from the model with shape (vocab_size,) :return: the sampled token """ + if self.deterministic: token = numpy.argmax(logits) self.tokens.append(token) return token + # make a copy of logits to avoid modifying the original + # logits distribution in-place + logits = logits.copy() + if self.top_k: logits = self.apply_top_k(logits) @@ -173,5 +178,5 @@ def _update_frequencies(self, token: numpy.ndarray): def _initialize_token_frequencies(self): unique_tokens, frequencies = numpy.unique(self.tokens, return_counts=True) - for token, frequnecies in zip(unique_tokens, frequencies): - self.token_frequencies[token] += frequnecies + for token, freq in zip(unique_tokens, frequencies): + self.token_frequencies[token] += freq