Update HF runner configs (#110)
Added a Llama 3.1 405B configuration.

Also fixed some comments in the config yaml.
In particular, for per_device_train_batch_size I based the comment on
http://shortn/_SR8wqscDlo.

Also changed the Dockerfile to install all transformers dependencies, because
the local_transformers checkout may be newer than our pinned transformers
dependency. In that case it may require a newer dependency such as
tokenizers==0.20 instead of tokenizers==0.19.
tengyifei authored Feb 13, 2025
1 parent ed8655d commit 07f3dfe
Showing 4 changed files with 45 additions and 5 deletions.
10 changes: 7 additions & 3 deletions torchprime/hf_models/configs/default.yaml
@@ -1,4 +1,4 @@
-# note: the sharding annonation is currently hard coded in pytorch-tpu/transformers
+# note: the sharding annotation is currently hard coded in pytorch-tpu/transformers
defaults:
- _self_

@@ -15,7 +15,11 @@ train_script:
args:
dataset_name: "wikitext"
dataset_config_name: "wikitext-103-raw-v1"
-per_device_train_batch_size: 256 # this is global batch size if use minibatch
+
+# If minibatch is False, this should be set to the global batch size.
+# If minibatch is True, this should be set to the per host batch size.
+per_device_train_batch_size: 256
+
do_train: true
output_dir: "test-clm"
overwrite_output_dir: true
@@ -32,4 +36,4 @@ train_script:
dataloader_drop_last: true
flash_attention: true
max_steps: 50
-seed: 42
+seed: 42
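For intuition, here is a minimal Python sketch of the batch-size semantics described in the new comment; num_hosts and the minibatch flag are illustrative names used only for this example, not the actual runner API.

# A minimal sketch of the batch-size semantics in the updated comment.
# `num_hosts` and `minibatch` are illustrative names, not the real runner API.
def global_batch_size(per_device_train_batch_size: int, num_hosts: int, minibatch: bool) -> int:
    """Return the effective global batch size for a given config value."""
    if minibatch:
        # With minibatch enabled, the config value is the per-host batch size,
        # so the global batch is that value times the number of hosts.
        return per_device_train_batch_size * num_hosts
    # With minibatch disabled, the config value already is the global batch size.
    return per_device_train_batch_size

# Example: 256 per host across 4 hosts.
print(global_batch_size(256, num_hosts=4, minibatch=True))   # 1024
print(global_batch_size(256, num_hosts=4, minibatch=False))  # 256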
34 changes: 34 additions & 0 deletions torchprime/hf_models/configs/model/llama-3/config_405b.json
@@ -0,0 +1,34 @@
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"hidden_act": "silu",
"hidden_size": 16384,
"initializer_range": 0.02,
"intermediate_size": 53248,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 128,
"num_hidden_layers": 126,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.42.3",
"use_cache": false,
"vocab_size": 128256
}
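As a quick sanity check, the new config can be loaded with Hugging Face transformers; this is an illustrative snippet assuming a recent transformers install and the repository root as the working directory, not part of the runner itself.

# Illustrative check that the 405B config loads and has the expected shape.
from transformers import LlamaConfig

config = LlamaConfig.from_json_file(
    "torchprime/hf_models/configs/model/llama-3/config_405b.json"
)
print(config.hidden_size, config.num_hidden_layers, config.num_key_value_heads)
# Expected: 16384 126 8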
3 changes: 3 additions & 0 deletions torchprime/hf_models/train.py
@@ -26,6 +26,9 @@ def build_command(config: DictConfig) -> list:
        args[k] = v

    for k, v in args.items():
+       if v is None:
+           # We may delete an argument by setting it to `null` on the CLI.
+           continue
        if isinstance(v, bool):
            if v:
                cmd.append(f"--{k}")
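A hedged sketch of the behaviour this adds: any argument whose value resolves to None (for example, one overridden to `null` on the command line) is dropped from the generated command instead of being rendered literally. The helper and argument names below are illustrative, not the real build_command.

# Illustrative reimplementation of the filtering loop above.
def build_flags(args: dict) -> list[str]:
    cmd: list[str] = []
    for k, v in args.items():
        if v is None:
            # Argument deleted by setting it to `null`; emit nothing.
            continue
        if isinstance(v, bool):
            if v:
                cmd.append(f"--{k}")
            continue
        cmd.extend([f"--{k}", str(v)])
    return cmd

print(build_flags({"do_train": True, "max_steps": 50, "output_dir": None}))
# ['--do_train', '--max_steps', '50']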
3 changes: 1 addition & 2 deletions torchprime/launcher/Dockerfile
@@ -52,8 +52,7 @@ RUN if [ "$USE_TRANSFORMERS" = "true" ] && [ -d "local_transformers" ]; then \

# Only install transformers if USE_TRANSFORMERS is true
RUN if [ "$USE_TRANSFORMERS" = "true" ]; then \
-pip install --no-deps -e /workspaces/torchprime/local_transformers; \
-pip install --no-deps evaluate; \
+pip install -e /workspaces/torchprime/local_transformers evaluate; \
fi

ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
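To see why installing local_transformers together with its dependencies matters, a small check like the following (illustrative, run inside the image's Python environment) compares the installed tokenizers version against what the transformers checkout declares; with --no-deps, a checkout pinning tokenizers==0.20 would silently run against 0.19.

# Illustrative dependency check; assumes transformers and tokenizers are installed.
from importlib.metadata import requires, version

print("transformers:", version("transformers"))
print("tokenizers:", version("tokenizers"))
# Which tokenizers range does the installed transformers actually declare?
print([r for r in (requires("transformers") or []) if r.lower().startswith("tokenizers")])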
