Skip to content

Commit

Permalink
[TextGeneration] Fix llama tokenizer (#1635)
Browse files Browse the repository at this point in the history
* add llama tokenizer fix

* fix generated string

* only run for streaming

* add TODO

---------

Co-authored-by: Dipika Sikka <dipikasikka1@gmail.coom>
  • Loading branch information
dsikka and Dipika Sikka authored Mar 14, 2024
1 parent e09ae26 commit 9bac61e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def run(
else [],
"finished_reason": [],
"token_generator": token_generator,
"past_tokens_queue": copy.copy(tokens),
}

if kv_cache is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from typing import Optional
from typing import List, Optional

import numpy

Expand Down Expand Up @@ -54,6 +54,33 @@ def _create_generated_text_output(
finished=False,
)

def _generate_streamed_text_from_past_tokens(
self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
) -> str:
"""
An auxiliary method that helps to properly generate the streamed text.
Some models like llama2 and mistral are using LlamaTokenizer which is
based on SentencePiece tokenizer. This specific tokenizer doesn't seem
to output appropriate prefix spaces when decoding token by token.
One can make it work if the previously generated tokens are included.
This allows the tokenizer to figure out that the appropriate spaces
from last n consecutive tokens.
:param generated_tokens: the generated tokens from the engine
:param past_tokens_queue: the queue of last n tokens (n is the
original prompt length in tokens)
:return: the generated string
"""
string_from_n_tokens = self.tokenizer.decode(
past_tokens_queue, skip_special_tokens=True
)
past_tokens_queue.append(generated_tokens[0])
string_from_n_plus_1_tokens = self.tokenizer.decode(
past_tokens_queue, skip_special_tokens=True
)
past_tokens_queue.pop(0)
return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]]

def run(
self,
generated_tokens: numpy.ndarray,
Expand All @@ -64,9 +91,24 @@ def run(
):
generation_config = inference_state.current_state.get("generation_config")
generated_logits = generated_logits if generation_config.output_scores else None
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)

import transformers

# Fix for LLAMA-specific models when running streaming
# TODO: make streaming a conditional input to this operator. using inference
# state is a quick fix.
if isinstance(
self.tokenizer,
(transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
) and inference_state.current_state.get("streaming"):
past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
sequences = self._generate_streamed_text_from_past_tokens(
generated_tokens, past_tokens_queue
)
else:
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)

try:
finished_reason = [f[-1] for f in finished_reason]
Expand Down

0 comments on commit 9bac61e

Please sign in to comment.