[TextGeneration] Fix llama tokenizer (#1635) #1636

Merged: 3 commits, merged on Mar 18, 2024. Showing changes from 1 commit.
@@ -101,6 +101,7 @@ def run(
             else [],
             "finished_reason": [],
             "token_generator": token_generator,
+            "past_tokens_queue": copy.copy(tokens),
         }

         if kv_cache is None:
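The line added above seeds the inference state with a copy of the prompt token ids. A minimal sketch, not part of the PR and using made-up token ids, of why an independent copy matters: the decode step added in the next file mutates the queue in place (append plus pop), so the state must not alias the original prompt token list.

```python
import copy

# Made-up prompt token ids; in the pipeline these come from tokenizing the prompt.
tokens = [1, 450, 4996, 17354]

# Mirror of the added line: store an independent copy in the inference state.
state = {"past_tokens_queue": copy.copy(tokens)}

# The decode step later appends each newly generated token and pops the oldest,
# mutating the queue in place...
state["past_tokens_queue"].append(1701)
state["past_tokens_queue"].pop(0)

# ...while the original prompt token list stays untouched.
assert tokens == [1, 450, 4996, 17354]
assert state["past_tokens_queue"] == [450, 4996, 17354, 1701]
```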
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime
-from typing import Optional
+from typing import List, Optional

 import numpy
@@ -54,6 +54,33 @@ def _create_generated_text_output(
             finished=False,
         )

+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> List[str]:
"""
An auxiliary method that helps to properly generate the streamed text.
Some models like llama2 and mistral are using LlamaTokenizer which is
based on SentencePiece tokenizer. This specific tokenizer doesn't seem
to output appropriate prefix spaces when decoding token by token.
One can make it work if the previously generated tokens are included.
This allows the tokenizer to figure out that the appropriate spaces
from last n consecutive tokens.

:param generated_tokens: the generated tokens from the engine
:param past_tokens_queue: the queue of last n tokens (n is the
original prompt length in tokens)
:return: the generated string
"""
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]]
+
     def run(
         self,
         generated_tokens: numpy.ndarray,
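A standalone sketch of the trick the helper above implements, not taken from the PR: decode the queue before and after appending the new token, and emit only the suffix the new token contributes. The Hugging Face checkpoint name is an assumption, chosen only because it ships a Llama-family tokenizer.

```python
from typing import List

from transformers import AutoTokenizer

# Illustrative Llama-family tokenizer; any LlamaTokenizer(Fast) checkpoint works.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")


def streamed_piece(new_token: int, past_tokens_queue: List[int]) -> str:
    # Decode the queue with and without the new token; the streamed piece is
    # whatever text the new token added at the end.
    before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token)
    after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)  # keep the queue at the original prompt length
    return after[len(before):]


prompt_ids = tokenizer("The quick brown").input_ids
generated = tokenizer(" fox jumps over", add_special_tokens=False).input_ids

queue = list(prompt_ids)
print("".join(streamed_piece(tok, queue) for tok in generated))
# Expected: " fox jumps over" with its spaces intact; decoding each token on
# its own would typically drop them.
```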
@@ -64,9 +91,24 @@ def run(
     ):
         generation_config = inference_state.current_state.get("generation_config")
         generated_logits = generated_logits if generation_config.output_scores else None
-        sequences = self.tokenizer.batch_decode(
-            generated_tokens, skip_special_tokens=True
-        )
+
+        import transformers
+
+        # Fix for LLAMA-specific models when running streaming
+        # TODO: make streaming a conditional input to this operator; using
+        # inference state is a quick fix.
+        if isinstance(
+            self.tokenizer,
+            (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
+        ) and inference_state.current_state.get("streaming"):
+            past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
+            sequences = self._generate_streamed_text_from_past_tokens(
+                generated_tokens, past_tokens_queue
+            )
+        else:
+            sequences = self.tokenizer.batch_decode(
+                generated_tokens, skip_special_tokens=True
+            )

         try:
             finished_reason = [f[-1] for f in finished_reason]
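To round this out, a hedged sketch (not from the PR) of the behavior the isinstance branch guards against: with a Llama/SentencePiece tokenizer, batch-decoding one generated token at a time loses the spaces between words, which is exactly what the past-tokens queue restores during streaming. The checkpoint name is illustrative.

```python
from transformers import AutoTokenizer, LlamaTokenizerFast

# Illustrative Llama-family tokenizer (the operator checks for this type).
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
assert isinstance(tokenizer, LlamaTokenizerFast)

ids = tokenizer(" fox jumps", add_special_tokens=False).input_ids

# Naive streaming: decode each generated token on its own.
pieces = tokenizer.batch_decode([[i] for i in ids], skip_special_tokens=True)
print("".join(pieces))  # typically "foxjumps" -- the word-separating spaces are gone

# Decoding the tokens together keeps the spaces; the past-tokens queue lets the
# streamed path reach the same result one token at a time.
print(tokenizer.decode(ids, skip_special_tokens=True))  # typically "fox jumps"
```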