add llama tokenizer fix
dsikka committed Mar 14, 2024
1 parent e09ae26 commit fd3b0d8
Showing 2 changed files with 40 additions and 4 deletions.
(first changed file)
@@ -101,6 +101,7 @@ def run(
             else [],
             "finished_reason": [],
             "token_generator": token_generator,
+            "past_tokens_queue": copy.copy(tokens),
         }

         if kv_cache is None:
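The entry added above seeds past_tokens_queue with a copy of the prompt tokens, so that the streaming decode introduced in the second changed file can mutate the queue in place without touching the original token list. A minimal sketch of the idea (hypothetical token ids, a plain dict standing in for the pipeline's inference state):

import copy

prompt_tokens = [1, 15043, 3186]  # hypothetical prompt token ids
state_update = {
    "finished_reason": [],
    "past_tokens_queue": copy.copy(prompt_tokens),  # mirrors the diff above
}

# The streaming step later appends each newly generated token and drops the
# oldest one, keeping the queue at the original prompt length:
state_update["past_tokens_queue"].append(29892)  # hypothetical new token id
state_update["past_tokens_queue"].pop(0)

assert prompt_tokens == [1, 15043, 3186]  # the caller's list is left unchanged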
(second changed file)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime
-from typing import Optional
+from typing import List, Optional

 import numpy

@@ -54,6 +54,33 @@ def _create_generated_text_output(
             finished=False,
         )

+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
"""
An auxiliary method that helps to properly generate the streamed text.
Some models like llama2 and mistral are using LlamaTokenizer which is
based on SentencePiece tokenizer. This specific tokenizer doesn't seem
to output appropriate prefix spaces when decoding token by token.
One can make it work if the previously generated tokens are included.
This allows the tokenizer to figure out that the appropriate spaces
from last n consecutive tokens.
:param generated_tokens: the generated tokens from the engine
:param past_tokens_queue: the queue of last n tokens (n is the
original prompt length in tokens)
:return: the generated string
"""
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return string_from_n_plus_1_tokens[len(string_from_n_tokens) :]
+
     def run(
         self,
         generated_tokens: numpy.ndarray,
@@ -64,9 +91,17 @@ def run(
     ):
         generation_config = inference_state.current_state.get("generation_config")
         generated_logits = generated_logits if generation_config.output_scores else None
-        sequences = self.tokenizer.batch_decode(
-            generated_tokens, skip_special_tokens=True
-        )
+
+        import transformers
+        if isinstance(self.tokenizer, (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast)):
+            past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
+            sequences = self._generate_streamed_text_from_past_tokens(
+                generated_tokens, past_tokens_queue
+            )
+        else:
+            sequences = self.tokenizer.batch_decode(
+                generated_tokens, skip_special_tokens=True
+            )

         try:
             finished_reason = [f[-1] for f in finished_reason]
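The core of the fix is the decode-and-diff step in _generate_streamed_text_from_past_tokens: decode the last n tokens, append the new token and decode again, then return only the added suffix, so the SentencePiece-based Llama tokenizer can reconstruct the prefix spaces it would otherwise drop when decoding token by token. A standalone sketch of the same idea, assuming a Llama-style tokenizer loaded through Hugging Face transformers (the checkpoint name is only an example and may require download access):

from typing import List

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint


def stream_decode(new_token: int, past_tokens_queue: List[int]) -> str:
    # Decode only the text contributed by new_token, using the previous tokens
    # as context so prefix spaces survive (same trick as the method above).
    before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token)
    after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)
    return after[len(before):]


token_ids = tokenizer.encode("The quick brown fox", add_special_tokens=False)

# Decoding one token at a time drops the spaces between words ...
naive = "".join(tokenizer.decode([t], skip_special_tokens=True) for t in token_ids)

# ... while the sliding-window diff reproduces the original spacing.
queue = token_ids[:1]
streamed = tokenizer.decode(queue, skip_special_tokens=True)
for t in token_ids[1:]:
    streamed += stream_decode(t, queue)

print(repr(naive))     # e.g. 'Thequickbrownfox' (prefix spaces lost)
print(repr(streamed))  # expected: 'The quick brown fox'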
