Commit

tiktoken
xzyaoi committed Mar 15, 2024
1 parent 675726c commit 30ac9dd
Showing 7 changed files with 37 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -2,6 +2,7 @@

FMEngine is our opinionated take on a foundation model training framework. The first version of FMEngine was built on top of `PyTorch` and `DeepSpeed` and was designed as a drop-in replacement for `DeepSpeed` with a few additional features. For `v2` we forked HuggingFace's `nanotron` and added some features to make it easier to use.


# Credits

We would like to thank everyone working on LLMs, especially those sharing their work openly, from which we took great inspiration:
File renamed without changes.
@@ -9,7 +9,7 @@ data:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: default
-hf_dataset_or_datasets: cerebras/SlimPajama-627B
+hf_dataset_or_datasets: DKYoon/SlimPajama-6B
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
@@ -51,7 +51,7 @@ model:
sliding_window: 4096
tie_word_embeddings: true
use_cache: true
-vocab_size: 32000
+vocab_size: 102000
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
@@ -70,16 +70,17 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
-dp: 2
+dp: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
+tokenizer_type: openai
tokenizer_max_length: null
-tokenizer_name_or_path: mistralai/Mistral-7B-v0.1
+tokenizer_name_or_path: cl100k_base
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
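The config changes above swap the training data to `DKYoon/SlimPajama-6B` and point the tokenizer at tiktoken's `cl100k_base` encoding (`tokenizer_type: openai`) instead of the Mistral Hugging Face tokenizer, with `vocab_size` raised from 32000 to 102000 to cover the larger vocabulary. The extra headroom above the encoding's actual size is presumably padding for parallelism; that is an assumption, not stated in the commit. A minimal sketch to sanity-check that the configured vocab size covers the encoding:

```python
import tiktoken

# Load the encoding named by tokenizer_name_or_path in the config above.
enc = tiktoken.get_encoding("cl100k_base")

# cl100k_base defines roughly 100k token ids (n_vocab includes special tokens),
# so the model's vocab_size (102000 in the config) must be at least enc.n_vocab.
print(enc.n_vocab)
assert 102000 >= enc.n_vocab
```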
6 changes: 2 additions & 4 deletions run_train.py
@@ -23,7 +23,7 @@
from fmengine.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from fmengine.trainer import DistributedTrainer
from fmengine.utils import main_rank_first

+from fmengine.tokenizer import get_tokenizer
from huggingface_hub import __version__ as hf_hub_version
from transformers import AutoTokenizer
from transformers import __version__ as tf_version
@@ -76,9 +76,7 @@ def get_dataloader(trainer: DistributedTrainer):
stream=True,
)['train']

-tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-tokenizer.pad_token = tokenizer.eos_token
-tokenizer.padding_side = "left"
+tokenizer = get_tokenizer(trainer.config)
# We apply the Causal Language Modeling preprocessing
train_dataset = clm_process(
raw_dataset=raw_dataset,
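With this change `run_train.py` no longer assumes a Hugging Face tokenizer: `get_tokenizer` may return a tiktoken `Encoding`, which has no `pad_token` or `padding_side`, so the HF-specific padding setup moves into the helper's `hf` branch. If a caller still needed to adjust padding at the call site, one option is to do it conditionally — a sketch only, not part of the commit; `trainer` is the object available in the surrounding `get_dataloader` context:

```python
from transformers import PreTrainedTokenizerBase

from fmengine.tokenizer import get_tokenizer

tokenizer = get_tokenizer(trainer.config)

# Only Hugging Face tokenizers carry padding attributes; tiktoken Encodings do not.
if isinstance(tokenizer, PreTrainedTokenizerBase):
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
```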
2 changes: 1 addition & 1 deletion src/fmengine/config/config.py
@@ -194,7 +194,7 @@ def __post_init__(self):
@dataclass
class TokenizerArgs:
"""Arguments related to the tokenizer"""

+tokenizer_type: Optional[str] = "hf"
tokenizer_name_or_path: Optional[str] = None
tokenizer_revision: Optional[str] = None
tokenizer_max_length: Optional[int] = None
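The new `tokenizer_type` field defaults to `"hf"`, so existing configs keep their old behaviour; setting it to `"openai"` routes tokenizer construction through tiktoken. A sketch of how the `tokenizer:` block from the YAML config above would map onto this dataclass (direct construction is shown only for illustration; in practice the config loader fills it in):

```python
from fmengine.config.config import TokenizerArgs

# Mirrors the tokenizer: block of the YAML config shown earlier.
args = TokenizerArgs(
    tokenizer_type="openai",              # new field; defaults to "hf"
    tokenizer_name_or_path="cl100k_base",
    tokenizer_revision=None,
    tokenizer_max_length=None,
)
```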
20 changes: 13 additions & 7 deletions src/fmengine/dataloader/dataloader.py
@@ -28,7 +28,7 @@
)
from transformers import PreTrainedTokenizerBase
from transformers.trainer_pt_utils import DistributedSamplerWithLoop

+import tiktoken
logger = logging.get_logger(__name__)

def sanity_check_dataloader(
@@ -375,12 +375,18 @@ def _group_texts(
return result

def _tokenize_texts(texts: List[str]) -> Dict[str, List[np.ndarray]]:
-tokenized_batch = tokenizer.encode(
-    texts,
-    return_attention_mask=False,
-    return_token_type_ids=False,
-    truncation=True,
-)
+if isinstance(tokenizer, PreTrainedTokenizerBase):
+    tokenized_batch = tokenizer.encode(
+        texts,
+        return_attention_mask=False,
+        return_token_type_ids=False,
+        truncation=True,
+    )
+    print(tokenized_batch)
+elif isinstance(tokenizer, tiktoken.core.Encoding):
+    tokenized_batch = tokenizer.encode_batch(texts)
+    # flatten the list of lists
+    tokenized_batch = [item for sublist in tokenized_batch for item in sublist]
return {"input_ids": tokenized_batch}

train_dataset = raw_dataset.map(
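The tiktoken branch relies on `Encoding.encode_batch`, which returns one list of token ids per input text; the list-of-lists is then flattened into a single stream of ids before the usual grouping into fixed-length blocks. A minimal sketch of that behaviour (the token ids shown in comments are illustrative):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# One list of token ids per input document.
batch = enc.encode_batch(["hello world", "foundation models"])

# Flatten exactly as the new dataloader branch does before grouping.
input_ids = [tok for doc in batch for tok in doc]

print(batch)      # e.g. [[15339, 1917], [...]] -- one sublist per text
print(input_ids)  # single flat list of ids
```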
15 changes: 15 additions & 0 deletions src/fmengine/tokenizer.py
@@ -0,0 +1,15 @@
def get_tokenizer(trainer_config):
    tokenizer_type = trainer_config.tokenizer.tokenizer_type.lower()
    tokenizer_path = trainer_config.tokenizer.tokenizer_name_or_path
    assert tokenizer_type in ['hf', 'openai'], f"Unknown tokenizer type {tokenizer_type}"
    if tokenizer_type == 'openai':
        import tiktoken
        tokenizer = tiktoken.get_encoding(tokenizer_path)
    elif tokenizer_type == 'hf':
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
    else:
        raise NotImplementedError(f"Tokenizer type {tokenizer_type} not implemented")
    return tokenizer
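A hedged usage sketch of the new helper. The config object below is a stand-in that only provides the attributes `get_tokenizer` actually reads; in the trainer it would be the full config object passed as `trainer.config`:

```python
from types import SimpleNamespace

from fmengine.tokenizer import get_tokenizer

# Stand-in config exposing only config.tokenizer.tokenizer_type / tokenizer_name_or_path.
config = SimpleNamespace(
    tokenizer=SimpleNamespace(
        tokenizer_type="openai",
        tokenizer_name_or_path="cl100k_base",
    )
)

tok = get_tokenizer(config)
print(tok.encode("FMEngine"))  # tiktoken Encoding: returns a list of token ids
```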
