Refactored modules/tokenizers to be a subdir of modules/transforms #2231
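The refactor moves every tokenizer import one level deeper, from `torchtune.modules.tokenizers` to `torchtune.modules.transforms.tokenizers`, as the diffs below show file by file. A minimal sketch of the mechanical rewrite applied across docs, recipes, tests, and datasets (`migrate_imports` is a hypothetical helper, not part of this PR):

```python
# Old and new dotted paths, taken from the diff in this PR.
OLD = "torchtune.modules.tokenizers"
NEW = "torchtune.modules.transforms.tokenizers"


def migrate_imports(source: str) -> str:
    """Rewrite the old tokenizer import path to the new one.

    The replacement is idempotent: the new path does not contain the
    old path as a substring, so running it twice is safe.
    """
    return source.replace(OLD, NEW)


line = "from torchtune.modules.tokenizers import ModelTokenizer"
print(migrate_imports(line))
# from torchtune.modules.transforms.tokenizers import ModelTokenizer
```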

Open · wants to merge 4 commits into base: main
12 changes: 6 additions & 6 deletions docs/source/api_ref_modules.rst
@@ -48,10 +48,10 @@ model specific tokenizers.
:toctree: generated/
:nosignatures:

tokenizers.SentencePieceBaseTokenizer
tokenizers.TikTokenBaseTokenizer
tokenizers.ModelTokenizer
tokenizers.BaseTokenizer
transforms.tokenizers.SentencePieceBaseTokenizer
transforms.tokenizers.TikTokenBaseTokenizer
transforms.tokenizers.ModelTokenizer
transforms.tokenizers.BaseTokenizer

Tokenizer Utilities
-------------------
@@ -61,8 +61,8 @@ These are helper methods that can be used by any tokenizer.
:toctree: generated/
:nosignatures:

tokenizers.tokenize_messages_no_special_tokens
tokenizers.parse_hf_tokenizer_json
transforms.tokenizers.tokenize_messages_no_special_tokens
transforms.tokenizers.parse_hf_tokenizer_json


PEFT Components
10 changes: 5 additions & 5 deletions docs/source/basics/tokenizers.rst
@@ -168,7 +168,7 @@ For example, here we change the ``"<|begin_of_text|>"`` and ``"<|end_of_text|>"``
Base tokenizers
---------------

:class:`~torchtune.modules.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back.
:class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back.
In torchtune, they are required to implement ``encode`` and ``decode`` methods, which are called by the :ref:`model_tokenizers` to convert
between raw text and token IDs.

@@ -202,13 +202,13 @@ between raw text and token IDs.
"""
pass

If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.tokenizers.BaseTokenizer`
If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer`
to do the actual encoding and decoding.

.. code-block:: python

from torchtune.models.mistral import mistral_tokenizer
from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer

m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
# Mistral uses SentencePiece for its underlying BPE
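The ``encode``/``decode`` contract described above can be sketched without torchtune installed. This is an illustrative stand-in, not torchtune's real BPE implementation; the class name, whitespace splitting, and toy vocabulary are invented for the example:

```python
from typing import List


class DummyBaseTokenizer:
    """Toy tokenizer implementing the BaseTokenizer-style contract:
    encode maps a raw string to token IDs, decode maps IDs back."""

    def __init__(self, vocab: List[str]):
        self.token_to_id = {tok: i for i, tok in enumerate(vocab)}
        self.id_to_token = {i: tok for i, tok in enumerate(vocab)}

    def encode(self, text: str) -> List[int]:
        # Real BPE operates on subword merges; whitespace splitting
        # is a simplification for illustration only.
        return [self.token_to_id[tok] for tok in text.split()]

    def decode(self, token_ids: List[int]) -> str:
        return " ".join(self.id_to_token[i] for i in token_ids)


tok = DummyBaseTokenizer(["hello", "world"])
ids = tok.encode("hello world")
print(ids)              # [0, 1]
print(tok.decode(ids))  # hello world
```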
@@ -227,7 +227,7 @@ to do the actual encoding and decoding.
Model tokenizers
----------------

:class:`~torchtune.modules.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method,
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method,
which converts a list of Messages into a list of token IDs.

.. code-block:: python
@@ -259,7 +259,7 @@ is because they add all the necessary special tokens or prompt templates require
.. code-block:: python

from torchtune.models.mistral import mistral_tokenizer
from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer
from torchtune.data import Message

m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
2 changes: 1 addition & 1 deletion recipes/eleuther_eval.py
@@ -31,8 +31,8 @@
from torchtune.modules import TransformerDecoder
from torchtune.modules.common_utils import local_kv_cache
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer
from torchtune.recipe_interfaces import EvalRecipeInterface
from torchtune.training import FullModelTorchTuneCheckpointer

2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -20,8 +20,8 @@
import torch
from torch import nn
from torchtune.data import Message, PromptTemplate, truncate
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer

skip_if_cuda_not_available = unittest.skipIf(
not torch.cuda.is_available(), "CUDA is not available"
@@ -7,7 +7,7 @@
import pytest

from tests.common import ASSETS
from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer


class TestSentencePieceBaseTokenizer:
@@ -8,7 +8,7 @@

from tests.common import ASSETS
from torchtune.models.llama3._tokenizer import CL100K_PATTERN
from torchtune.modules.tokenizers import TikTokenBaseTokenizer
from torchtune.modules.transforms.tokenizers import TikTokenBaseTokenizer


class TestTikTokenBaseTokenizer:
@@ -9,7 +9,7 @@
from tests.test_utils import DummyTokenizer
from torchtune.data import Message

from torchtune.modules.tokenizers import tokenize_messages_no_special_tokens
from torchtune.modules.transforms.tokenizers import tokenize_messages_no_special_tokens


class TestTokenizerUtils:
7 changes: 4 additions & 3 deletions torchtune/data/_messages.py
@@ -22,9 +22,10 @@
class Message:
"""
This class represents individual messages in a fine-tuning dataset. It supports
text-only content, text with interleaved images, and tool calls. The :class:`~torchtune.modules.tokenizers.ModelTokenizer`
will tokenize the content of the message using ``tokenize_messages`` and attach
the appropriate special tokens based on the flags set in this class.
text-only content, text with interleaved images, and tool calls. The
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
the content of the message using ``tokenize_messages`` and attach the appropriate
special tokens based on the flags set in this class.

Args:
role (Role): role of the message writer. Can be "system" for system prompts,
2 changes: 1 addition & 1 deletion torchtune/datasets/_alpaca.py
@@ -12,7 +12,7 @@

from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def alpaca_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_chat.py
@@ -9,7 +9,7 @@
from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def chat_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_cnn_dailymail.py
@@ -8,7 +8,7 @@

from torchtune.datasets._text_completion import TextCompletionDataset

from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def cnn_dailymail_articles_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_grammar.py
@@ -10,7 +10,7 @@
from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def grammar_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_hh_rlhf_helpful.py
@@ -8,7 +8,7 @@

from torchtune.data import ChosenRejectedToMessages
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def hh_rlhf_helpful_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_instruct.py
@@ -9,7 +9,7 @@
from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def instruct_dataset(
6 changes: 3 additions & 3 deletions torchtune/datasets/_preference.py
@@ -11,10 +11,10 @@
from torch.utils.data import Dataset

from torchtune.data import ChosenRejectedToMessages, CROSS_ENTROPY_IGNORE_IDX

from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform

from torchtune.modules.transforms.tokenizers import ModelTokenizer


class PreferenceDataset(Dataset):
"""
@@ -84,7 +84,7 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
of messages are stored in the ``"chosen"`` and ``"rejected"`` keys.
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
Since PreferenceDataset only supports text data, it requires a
:class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
:class:`~torchtune.datasets.SFTDataset`.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
2 changes: 1 addition & 1 deletion torchtune/datasets/_samsum.py
@@ -10,7 +10,7 @@
from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def samsum_dataset(
12 changes: 7 additions & 5 deletions torchtune/datasets/_sft.py
@@ -69,11 +69,13 @@ class SFTDataset(Dataset):
multimodal datasets requires processing the images in a way specific to the vision
encoder being used by the model and is agnostic to the specific dataset.

Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`
can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
transform the list of messages outputted from the ``message_transform`` into tokens
used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
into ``model_transform``. Tokenizers handle prompt templating, if configured.
Tokenization is handled by the ``model_transform``. All
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` can be treated as
a ``model_transform`` since it uses the model-specific tokenizer to transform the
list of messages outputted from the ``message_transform`` into tokens used by the
model for training. Text-only datasets will simply pass the
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` into ``model_transform``.
Tokenizers handle prompt templating, if configured.
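The "tokenizer as ``model_transform``" idea in this docstring can be sketched as any callable that maps a sample's ``"messages"`` to token IDs. The class below is a toy stand-in invented for illustration (length-based "token IDs", role-based loss mask), not torchtune's ModelTokenizer:

```python
from typing import Any, List, Mapping


class ToyMessageTokenizer:
    """Minimal model_transform-style callable: dict in, dict out."""

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        tokens: List[int] = []
        mask: List[bool] = []
        for message in sample["messages"]:
            for word in message["content"].split():
                tokens.append(len(word))  # stand-in for a real token ID
                # Toy convention: mask (ignore in the loss) user tokens.
                mask.append(message["role"] == "user")
        return {"tokens": tokens, "mask": mask}


sample = {"messages": [
    {"role": "user", "content": "hi there"},
    {"role": "assistant", "content": "hello"},
]}
print(ToyMessageTokenizer()(sample))
# {'tokens': [2, 5, 5], 'mask': [True, True, False]}
```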

Args:
source (str): path to dataset repository on Hugging Face. For local datasets,
2 changes: 1 addition & 1 deletion torchtune/datasets/_slimorca.py
@@ -10,7 +10,7 @@
from torchtune.datasets._packed import PackedDataset

from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def slimorca_dataset(
2 changes: 1 addition & 1 deletion torchtune/datasets/_stack_exchange_paired.py
@@ -8,8 +8,8 @@

from torchtune.data import Message
from torchtune.datasets._preference import PreferenceDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer


class StackExchangePairedToMessages(Transform):
2 changes: 1 addition & 1 deletion torchtune/datasets/_text_completion.py
@@ -10,7 +10,7 @@
from torch.utils.data import Dataset
from torchtune.data._utils import truncate
from torchtune.datasets._packed import PackedDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


class TextCompletionDataset(Dataset):
2 changes: 1 addition & 1 deletion torchtune/datasets/_wikitext.py
@@ -13,7 +13,7 @@
TextCompletionDataset,
)

from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms.tokenizers import ModelTokenizer


def wikitext_dataset(
2 changes: 1 addition & 1 deletion torchtune/models/clip/_tokenizer.py
@@ -7,7 +7,7 @@

import regex as re

from torchtune.modules.tokenizers._utils import BaseTokenizer
from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer

WORD_BOUNDARY = "</w>"

4 changes: 2 additions & 2 deletions torchtune/models/gemma/_tokenizer.py
@@ -7,12 +7,12 @@
from typing import Any, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate
from torchtune.modules.tokenizers import (
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import (
ModelTokenizer,
SentencePieceBaseTokenizer,
tokenize_messages_no_special_tokens,
)
from torchtune.modules.transforms import Transform

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]

4 changes: 2 additions & 2 deletions torchtune/models/llama2/_tokenizer.py
@@ -8,12 +8,12 @@

from torchtune.data import Message, PromptTemplate
from torchtune.models.llama2._prompt_template import Llama2ChatTemplate
from torchtune.modules.tokenizers import (
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import (
ModelTokenizer,
SentencePieceBaseTokenizer,
tokenize_messages_no_special_tokens,
)
from torchtune.modules.transforms import Transform

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]

2 changes: 1 addition & 1 deletion torchtune/models/llama3/_model_builders.py
@@ -13,7 +13,7 @@

from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json


"""
5 changes: 4 additions & 1 deletion torchtune/models/llama3/_tokenizer.py
@@ -8,8 +8,11 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple

from torchtune.data import Message, PromptTemplate, truncate
from torchtune.modules.tokenizers import ModelTokenizer, TikTokenBaseTokenizer
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import (
ModelTokenizer,
TikTokenBaseTokenizer,
)


CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa
1 change: 1 addition & 0 deletions torchtune/models/llama3_2_vision/_model_builders.py
@@ -20,6 +20,7 @@
from torchtune.models.llama3_2_vision._transform import Llama3VisionTransform
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json


def llama3_2_vision_transform(
2 changes: 1 addition & 1 deletion torchtune/models/llama3_2_vision/_transform.py
@@ -10,8 +10,8 @@

from torchtune.models.clip import CLIPImageTransform
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform, VisionCrossAttentionMask
from torchtune.modules.transforms.tokenizers import ModelTokenizer


class Llama3VisionTransform(ModelTokenizer, Transform):
4 changes: 2 additions & 2 deletions torchtune/models/mistral/_tokenizer.py
@@ -8,12 +8,12 @@

from torchtune.data import Message, PromptTemplate
from torchtune.models.mistral._prompt_template import MistralChatTemplate
from torchtune.modules.tokenizers import (
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import (
ModelTokenizer,
SentencePieceBaseTokenizer,
tokenize_messages_no_special_tokens,
)
from torchtune.modules.transforms import Transform

WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]

2 changes: 1 addition & 1 deletion torchtune/models/phi3/_model_builders.py
@@ -6,7 +6,7 @@
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from functools import partial
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json
from torchtune.data._prompt_templates import _TemplateType
from torchtune.data._prompt_templates import _get_prompt_template

5 changes: 4 additions & 1 deletion torchtune/models/phi3/_tokenizer.py
@@ -9,8 +9,11 @@
from torchtune.data._messages import Message
from torchtune.data._prompt_templates import PromptTemplate
from torchtune.data._utils import truncate
from torchtune.modules.tokenizers import ModelTokenizer, SentencePieceBaseTokenizer
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import (
ModelTokenizer,
SentencePieceBaseTokenizer,
)

PHI3_SPECIAL_TOKENS = {
"<|endoftext|>": 32000,
2 changes: 1 addition & 1 deletion torchtune/models/qwen2/_model_builders.py
@@ -11,7 +11,7 @@
from torchtune.models.qwen2._tokenizer import QWEN2_SPECIAL_TOKENS, Qwen2Tokenizer
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json

"""
Model builders build specific instantiations using component builders. For example