diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index 23dbd3a82..6e8b6e39d 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -167,26 +167,27 @@ def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict: f"layers.{layer_idx}.feed_forward.w3.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight", } + def _hf_to_maxtext_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict: - # pylint: disable=line-too-long + # pylint: disable=line-too-long return { - "model.embed_tokens.weight": "tok_embeddings.weight", - "model.norm.weight": "norm.weight", - "lm_head.weight": "output.weight", - f"model.layers.{layer_idx}.input_layernorm.weight": f"layers.{layer_idx}.attention_norm.weight", - f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"layers.{layer_idx}.ffn_norm.weight", - f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"layers.{layer_idx}.attention.wq.weight", - f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight", - f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight", - f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight", - # MOE model - f"model.layers.{layer_idx}.block_sparse_moe.gate.weight": f"layers.{layer_idx}.feed_forward.gate.weight", - f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight", - f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight", - f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w3.weight", - f"model.layers.{layer_idx}.mlp.gate_proj.weight": f"layers.{layer_idx}.feed_forward.w1.weight", - f"model.layers.{layer_idx}.mlp.down_proj.weight": f"layers.{layer_idx}.feed_forward.w2.weight", - f"model.layers.{layer_idx}.mlp.up_proj.weight": f"layers.{layer_idx}.feed_forward.w3.weight" + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + f"model.layers.{layer_idx}.input_layernorm.weight": f"layers.{layer_idx}.attention_norm.weight", + f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"layers.{layer_idx}.ffn_norm.weight", + f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"layers.{layer_idx}.attention.wq.weight", + f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight", + f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight", + f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight", + # MOE model + f"model.layers.{layer_idx}.block_sparse_moe.gate.weight": f"layers.{layer_idx}.feed_forward.gate.weight", + f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight", + f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight", + f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w3.weight", + f"model.layers.{layer_idx}.mlp.gate_proj.weight": f"layers.{layer_idx}.feed_forward.w1.weight", + f"model.layers.{layer_idx}.mlp.down_proj.weight": 
f"layers.{layer_idx}.feed_forward.w2.weight", + f"model.layers.{layer_idx}.mlp.up_proj.weight": f"layers.{layer_idx}.feed_forward.w3.weight", } @@ -220,6 +221,7 @@ def permute_to_match_maxtext_rope(arr): x[..., 1::2] = odds return x + def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info): base_num_decoder_layers = model_params["num_layers"] base_num_query_heads = model_params["num_heads"] @@ -239,7 +241,7 @@ def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ parts = key.split(".") layer = int(parts[2]) if "layers" in key else 0 mapped_key = _hf_to_maxtext_mapping(layer)[key] - chkpt_vars[mapped_key]=f.get_tensor(key) + chkpt_vars[mapped_key] = f.get_tensor(key) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) @@ -269,13 +271,14 @@ def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ # logits dense ################################################# max_logging.log("Processing logits dense") - jax_weights["decoder"]["logits_dense"]["kernel"] = chkpt_vars["output.weight"].to(torch.float16).numpy().transpose()[:, :vocab_size] + jax_weights["decoder"]["logits_dense"]["kernel"] = ( + chkpt_vars["output.weight"].to(torch.float16).numpy().transpose()[:, :vocab_size] + ) # logits_dense = np.concatenate( # [var["output.weight"].type(torch.float16).numpy() for var in chkpt_vars], axis=0 # ).transpose()[:, :vocab_size] # jax_weights["decoder"]["logits_dense"]["kernel"] = logits_dense - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # token embedding ############################################## @@ -291,7 +294,9 @@ def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ if model_size[:6] == "llama3": jax_weights["token_embedder"]["embedding"] = chkpt_vars["tok_embeddings.weight"].to(torch.float16).numpy() else: - jax_weights["token_embedder"]["embedding"] = chkpt_vars["tok_embeddings.weight"].to(torch.float16).numpy()[:vocab_size, :] + jax_weights["token_embedder"]["embedding"] = ( + chkpt_vars["tok_embeddings.weight"].to(torch.float16).numpy()[:vocab_size, :] + ) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) @@ -522,7 +527,7 @@ def convert_to_jax_weights(base_model_path, model_size, huggingface_ckpt): max_logging.log(f"Loading the base model from {base_model_path}") # Skip any hidden files for checkpoints if huggingface_ckpt: - return convert_huggingface_to_jax_weights( base_model_path, model_size, huggingface_ckpt, model_params, mem_info) + return convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info) chkpt_vars = {} checkpoint = {} ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth")) @@ -831,6 +836,7 @@ def checkpoint_device_put(arr): # Upon preemption, exit when and only when all ongoing saves are complete. 
checkpoint_manager.wait_until_finished()
 
+
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument("--base-model-path", type=str, required=True)
@@ -843,4 +849,6 @@ def checkpoint_device_put(arr):
     raise NotImplementedError
 
   os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}"
-  save_jax_weights_to_checkpoint(args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size, args.huggingface_checkpoint))
+  save_jax_weights_to_checkpoint(
+      args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size, args.huggingface_checkpoint)
+  )
diff --git a/MaxText/tests/__init__.py b/MaxText/tests/__init__.py
new file mode 100644
index 000000000..c133d2d71
--- /dev/null
+++ b/MaxText/tests/__init__.py
@@ -0,0 +1,15 @@
+"""
+Copyright 2023 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
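The new `MaxText/tests/hf_checkpoint_conversion_checker.py` below loads every `*.safetensors` shard from two checkpoint folders and compares them leaf by leaf. The core round-trip check can be summarized in a short standalone sketch (illustrative only: the folder paths are caller-supplied, the `1e-3` tolerance mirrors the checker's default `atol`, and only `safetensors`, `torch`, and `numpy` are assumed):

```python
import glob
import os

import numpy as np
import torch
from safetensors import safe_open


def load_safetensors(folder):
  """Collect every tensor from all *.safetensors shards in `folder` into one dict."""
  tensors = {}
  for shard in glob.glob(os.path.join(folder, "*.safetensors")):
    with safe_open(shard, framework="pt", device="cpu") as f:
      for key in f.keys():
        tensors[key] = f.get_tensor(key).to(torch.float16).numpy()
  return tensors


def checkpoints_match(folder_a, folder_b, atol=1e-3):
  """True iff both folders hold the same weight names and elementwise-close values."""
  a, b = load_safetensors(folder_a), load_safetensors(folder_b)
  if a.keys() != b.keys():
    return False
  return all(np.allclose(a[k], b[k], atol=atol) for k in a)
```

The checker below performs the same comparison but walks the weights as a nested pytree, so it can report the first mismatching leaf by name.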
+ """ + # Ensure both trees have the same structure + if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2): + print( + f"Pytrees have different structures! Tree1: {jax.tree_util.tree_structure(tree1)} \n\n\nTree2: {jax.tree_util.tree_structure(tree2)}" + ) + return False + + # Compare leaves with names + def get_named_leaves(pytree, parent_key=""): + named_leaves = {} + for key, value in pytree.items(): + new_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, dict): + named_leaves.update(get_named_leaves(value, new_key)) + else: + named_leaves[new_key] = value + return named_leaves + + named_leaves1 = get_named_leaves(tree1) + named_leaves2 = get_named_leaves(tree2) + + for key in named_leaves1: + # import pdb; pdb.set_trace() + if key not in named_leaves2: + print(f"Missing key in second tree: {key}") + return False + try: + if not np.allclose(named_leaves1[key], named_leaves2[key], atol=atol): + # print(f"Mismatch at leaf '{key}':\n{named_leaves1[key]}\n{named_leaves2[key]}") + # return False + # print(f"Mismatch at leaf '{key}'") + mismatch_values1 = named_leaves1[key].flatten()[:10] + mismatch_values2 = named_leaves2[key].flatten()[:10] + # print(f"Mismatch at leaf '{key}':\nFirst 10 elements:\n{mismatch_values1}\n{mismatch_values2}") + print(f"Mismatch at leaf '{key}' with shape {named_leaves1[key].shape}:\n") + for i in range(10): + print(f"{named_leaves1[key][..., i, :]}\n") + print(f"The second tensor:\n") + for i in range(10): + print(f"{named_leaves2[key][..., i, :]}\n") + return + except: + print(f"The issue is with {key}") + # print(f"Checking {key} done") + + print("All leaves match within tolerance.") + return True + + +def test_huggingface_to_maxtext_back_to_huggingface_flow(): + base_num_query_heads = base_num_kv_heads = 16 + head_dim = 32 + # import pdb; pdb.set_trace() + wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape( + base_num_query_heads * head_dim, base_num_query_heads * head_dim + ) + wq1 = wq.transpose() + wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) + + wq3 = llama_or_mistral_ckpt.permute_to_match_maxtext_rope(wq2) + stack_shape = (1,) + x = np.zeros(stack_shape + wq3.shape, dtype=np.float16) + x[0, ...] 
= wq3 + x = np.transpose(x, axes=(1, 0, 2, 3)) + + x = x[:, 0, :, :] + wq4 = llama_mistral_mixtral_orbax_to_hf.unpermute_from_match_maxtext_rope(x, "llama3.1") + wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) + wq6 = wq5.transpose() + + if not np.array_equal(wq, wq6): + print("Test failed: wq does not match wq6") + + if not np.array_equal(wq1, wq5): + print("Test failed: wq1 does not match wq5") + + if not np.array_equal(wq2, wq4): + print("Test failed: wq2 does not match wq4") + + +def main(): + parser = argparse.ArgumentParser(description="Compares the original checkpoint and converted back checkpoint.") + parser.add_argument( + "--original_ckpt", + type=str, + default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B", + help="The original huggingface checkpoint", + ) + parser.add_argument( + "--converted_ckpt", + type=str, + default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B/", + help="The original huggingface checkpoint", + ) + args = parser.parse_args() + + hf_checkpoint_folder = args.original_ckpt + hf_tensor = load_hf(hf_checkpoint_folder) + + meta_checkpoint_folder = args.converted_ckpt + meta_tensor = load_hf(meta_checkpoint_folder) + + compare_pytrees(hf_tensor, meta_tensor) + + +if __name__ == "__main__": + main() diff --git a/MaxText/tests/hf_checkpoint_conversion_test.py b/MaxText/tests/hf_checkpoint_conversion_test.py deleted file mode 100644 index 3182c95a0..000000000 --- a/MaxText/tests/hf_checkpoint_conversion_test.py +++ /dev/null @@ -1,139 +0,0 @@ -import glob -import os -import torch -from safetensors import safe_open -import pathlib -import jax -import jax.numpy as jnp -import numpy as np -import torch -from llama_or_mistral_ckpt import permute_to_match_maxtext_rope -from llama_mistral_mixtral_orbax_to_hf import unpermute_from_match_maxtext_rope -import sys -import jax -from jax.sharding import Mesh -from jax.experimental import mesh_utils - -import pyconfig -import pytest - -import unittest - -def load_hf(hf_checkpoint_folder): - safetensor_files = glob.glob(os.path.join(hf_checkpoint_folder, "*.safetensors")) - - hf_tensor = {} - for st_f in safetensor_files: - with safe_open(st_f, framework="pt", device="cpu") as f: - for key in f.keys(): - hf_tensor[key] = f.get_tensor(key).to(torch.float16) - # print(f"Weight name {key}, Shape: {hf_tensor.shape}, dtype: {hf_tensor[key].dtype}") - return hf_tensor - -def load_meta(meta_checkpoint_folder): - meta_tensor = {} - ckpt_paths = sorted(pathlib.Path(meta_checkpoint_folder).glob("[!.]*.pth")) - for i, ckpt_path in enumerate(ckpt_paths): - meta_tensor = torch.load(ckpt_path, map_location="cpu") - # chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint - # chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))] - return meta_tensor - - -def compare_pytrees(tree1, tree2, atol=0.001): - """ - Compares two JAX pytrees to check if all leaf values are within the given absolute tolerance. - - Args: - tree1: First pytree. - tree2: Second pytree. - atol: Absolute tolerance for comparison (default: 0.001). - - Returns: - A boolean indicating if all leaf values are within the specified range. - """ - # Ensure both trees have the same structure - if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2): - print(f"Pytrees have different structures! 
Tree1: {jax.tree_util.tree_structure(tree1)} \n\n\nTree2: {jax.tree_util.tree_structure(tree2)}") - return False - - # Compare leaves with names - def get_named_leaves(pytree, parent_key=""): - named_leaves = {} - for key, value in pytree.items(): - new_key = f"{parent_key}.{key}" if parent_key else key - if isinstance(value, dict): - named_leaves.update(get_named_leaves(value, new_key)) - else: - named_leaves[new_key] = value - return named_leaves - - named_leaves1 = get_named_leaves(tree1) - named_leaves2 = get_named_leaves(tree2) - - for key in named_leaves1: - # import pdb; pdb.set_trace() - if key not in named_leaves2: - print(f"Missing key in second tree: {key}") - return False - try: - if not np.allclose(named_leaves1[key], named_leaves2[key], atol=atol): - # print(f"Mismatch at leaf '{key}':\n{named_leaves1[key]}\n{named_leaves2[key]}") - # return False - # print(f"Mismatch at leaf '{key}'") - mismatch_values1 = named_leaves1[key].flatten()[:10] - mismatch_values2 = named_leaves2[key].flatten()[:10] - # print(f"Mismatch at leaf '{key}':\nFirst 10 elements:\n{mismatch_values1}\n{mismatch_values2}") - print(f"Mismatch at leaf '{key}' with shape {named_leaves1[key].shape}:\n") - for i in range(10): - print(f"{named_leaves1[key][..., i, :]}\n") - print(f"The second tensor:\n") - for i in range(10): - print(f"{named_leaves2[key][..., i, :]}\n") - return - except: - print(f"The issue is with {key}") - # print(f"Checking {key} done") - - print("All leaves match within tolerance.") - return True - -def test_huggingface_to_maxtext_back_to_huggingface_flow(): - base_num_query_heads = base_num_kv_heads = 16 - head_dim = 32 - # import pdb; pdb.set_trace() - wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape(base_num_query_heads * head_dim , base_num_query_heads * head_dim ) - wq1 = wq.transpose() - wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) - - wq3 = permute_to_match_maxtext_rope(wq2) - stack_shape = (1,) - x = np.zeros(stack_shape + wq3.shape, dtype=np.float16) - x[0, ...] = wq3 - x = np.transpose(x, axes=(1, 0, 2, 3)) - - x = x[:, 0, :, :] - wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1") - wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) - wq6 = wq5.transpose() - - if not np.array_equal(wq, wq6): - print("Test failed: wq does not match wq6") - - if not np.array_equal(wq1, wq5): - print("Test failed: wq1 does not match wq5") - - if not np.array_equal(wq2, wq4): - print("Test failed: wq2 does not match wq4") - - -if __name__ == "__main__": - hf_checkpoint_folder = "/mnt/disks/persist/checkpoints/huggingface/Llama3.1-8B" - hf_tensor = load_hf(hf_checkpoint_folder) - - meta_checkpoint_folder = "/tmp/hf_llama3_1_no_perm" - meta_tensor = load_hf(meta_checkpoint_folder) - - - compare_pytrees(hf_tensor, meta_tensor) - diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index f1cf483ee..bda38f4b6 100644 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -34,6 +34,7 @@ from maxtext_xpk_runner import XpkClusterConfig from maxtext_xpk_runner import LibTpuType + def add_xpk_runner_arguments(custom_parser: argparse.ArgumentParser): """Add arguments to the xpk runner parser. @@ -41,129 +42,131 @@ def add_xpk_runner_arguments(custom_parser: argparse.ArgumentParser): custom_parser: parser to add shared arguments to. 
""" custom_parser.add_argument( - '--project', + "--project", type=str, - default=None, required=True, + default=None, + required=True, help='GCE project name, defaults to "gcloud config project."', ) custom_parser.add_argument( - '--zone', + "--zone", type=str, - default=None, required=True, + default=None, + required=True, help=( 'GCE zone, e.g. us-central2-b, defaults to "gcloud config ' 'compute/zone." Only one of --zone or --region is allowed in a ' - 'command.' + "command." ), ) custom_parser.add_argument( - '--cluster_name', + "--cluster_name", type=str, - default=None, required=True, - help='cluster name The name of the cluster to run the job on. command.', + default=None, + required=True, + help="cluster name The name of the cluster to run the job on. command.", ) custom_parser.add_argument( - '--base_output_directory', + "--base_output_directory", type=str, - default=None, required=True, - help='gcloud bucket to store artifacts.', + default=None, + required=True, + help="gcloud bucket to store artifacts.", ) custom_parser.add_argument( - '--device_type', + "--device_type", type=str, - default=None, required=True, - help='tpu device type command.', + default=None, + required=True, + help="tpu device type command.", ) custom_parser.add_argument( - '--num_slices', + "--num_slices", type=int, - default='1', - help='Number of slices for tpu devices command.', + default="1", + help="Number of slices for tpu devices command.", ) custom_parser.add_argument( - '--model_name', + "--model_name", type=str, choices=list(trillium_model_dict.keys()) + list(v5e_model_dict.keys()), default=list(trillium_model_dict.keys())[0], - help=( - f'model to be benchmarked, supported models are the command choices.' - ), + help=(f"model to be benchmarked, supported models are the command choices."), ) custom_parser.add_argument( - '--libtpu_version', + "--libtpu_version", type=str, - default='20241009', - help='version of libtpu-nightly to be benchmarked command.', + default="20241009", + help="version of libtpu-nightly to be benchmarked command.", ) custom_parser.add_argument( - '--libtpu_type', + "--libtpu_type", type=str, choices=[t.value for t in LibTpuType], - default='nightly', - help='type of libtpu to be benchmarked command.', + default="nightly", + help="type of libtpu to be benchmarked command.", ) custom_parser.add_argument( - '--base_docker_image', + "--base_docker_image", type=str, - default='maxtext_base_image', - help='version of base docker image to be benchmarked command.', + default="maxtext_base_image", + help="version of base docker image to be benchmarked command.", ) custom_parser.add_argument( - '--pathways_server_image', + "--pathways_server_image", type=str, - default=( - 'us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest' - ), - help='version of pathways server image to be benchmarked command.', + default=("us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest"), + help="version of pathways server image to be benchmarked command.", ) custom_parser.add_argument( - '--pathways_proxy_server_image', + "--pathways_proxy_image", type=str, - default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest', - help='version of pathways proxy image to be benchmarked command.', + default="us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest", + help="version of pathways proxy image to be benchmarked command.", ) custom_parser.add_argument( - '--pathways_runner_image', + "--pathways_runner_image", type=str, - 
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest', - help='version of pathways runner image to be benchmarked command.', + default="us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest", + help="version of pathways runner image to be benchmarked command.", ) custom_parser.add_argument( - '--remote_python_sidecar_image', + "--remote_python_sidecar_image", type=str, - help='version of remote python sidecar image to be benchmarked command.', + help="version of remote python sidecar image to be benchmarked command.", ) custom_parser.add_argument( - '--use_pathways', + "--use_pathways", type=bool, default=False, - help='whether to use pathways or not.', + help="whether to use pathways or not.", ) custom_parser.add_argument( - '--xpk_path', + "--xpk_path", type=str, - default='~/xpk', - help='path to xpk dir.', + default="~/xpk", + help="path to xpk dir.", ) custom_parser.add_argument( - '--priority', + "--priority", type=str, - default='medium', - help='Priority the XPK workload should run with.', + default="medium", + help="Priority the XPK workload should run with.", ) custom_parser.add_argument( - '--num_steps', + "--num_steps", type=int, default=20, - help='Number of steps to run the workload for.', + help="Number of steps to run the workload for.", ) custom_parser.add_argument( - '--max_restarts', + "--max_restarts", type=int, default=0, - help='Number of restarts to attempt.', + help="Number of restarts to attempt.", ) + def add_on_device_runner_arguments(custom_parser: argparse.ArgumentParser): """Add arguments to the on-device runner parser. @@ -171,50 +174,48 @@ def add_on_device_runner_arguments(custom_parser: argparse.ArgumentParser): custom_parser: parser to add shared arguments to. """ custom_parser.add_argument( - '--base_output_directory', + "--base_output_directory", type=str, - default=None, required=True, - help='gcloud bucket to store artifacts.', + default=None, + required=True, + help="gcloud bucket to store artifacts.", ) custom_parser.add_argument( - '--run_name', + "--run_name", type=str, default=None, - help='run_name for model run', + help="run_name for model run", ) custom_parser.add_argument( - '--model_name', + "--model_name", type=str, choices=list(trillium_model_dict.keys()) + list(v5e_model_dict.keys()), default=list(trillium_model_dict.keys())[0], - help=( - f'model to be benchmarked, supported models are the command choices.' 
- ), + help=(f"model to be benchmarked, supported models are the command choices."), ) custom_parser.add_argument( - '--libtpu_version', + "--libtpu_version", type=str, - default='20241009', - help='version of libtpu-nightly to be benchmarked command.', + default="20241009", + help="version of libtpu-nightly to be benchmarked command.", ) custom_parser.add_argument( - '--libtpu_type', + "--libtpu_type", type=str, choices=[t.value for t in LibTpuType], - default='nightly', - help='type of libtpu to be benchmarked command.', + default="nightly", + help="type of libtpu to be benchmarked command.", ) custom_parser.add_argument( - '--num_steps', + "--num_steps", type=int, default=20, - help='Number of steps to run the workload for.', + help="Number of steps to run the workload for.", ) + def main() -> None: - parser = argparse.ArgumentParser( - prog='benchmark runner', usage='%(prog)s [options]' - ) + parser = argparse.ArgumentParser(prog="benchmark runner", usage="%(prog)s [options]") subparsers = parser.add_subparsers(help="", dest="runner") xpk_runner_parser = subparsers.add_parser("xpk") on_device_runner_parser = subparsers.add_parser("on-device") @@ -223,8 +224,8 @@ def main() -> None: options = parser.parse_args() # Check that there are no duplicate model configs - duplicates = (trillium_model_dict.keys() & v5e_model_dict.keys()) - assert len(duplicates) == 0 , f'Found duplicate model config {duplicates}' + duplicates = trillium_model_dict.keys() & v5e_model_dict.keys() + assert len(duplicates) == 0, f"Found duplicate model config {duplicates}" model = trillium_model_dict.get(options.model_name) if model is None: @@ -241,60 +242,57 @@ def main() -> None: if options.runner == "xpk": cluster_config = XpkClusterConfig( - cluster_name=options.cluster_name, - project=options.project, - zone=options.zone, - device_type=options.device_type + cluster_name=options.cluster_name, project=options.project, zone=options.zone, device_type=options.device_type ) pw_config = None if options.use_pathways: pw_config = PathwaysConfig( - server_image=options.pathways_server_image, - proxy_server_image=options.pathways_proxy_server_image, - runner_image=options.pathways_runner_image, - remote_python_sidecar_image=options.remote_python_sidecar_image, + server_image=options.pathways_server_image, + proxy_image=options.pathways_proxy_image, + runner_image=options.pathways_runner_image, + remote_python_sidecar_image=options.remote_python_sidecar_image, ) workload_config = WorkloadConfig( - model=model, - num_slices=options.num_slices, - num_steps=options.num_steps, - device_type=options.device_type, - base_output_directory=options.base_output_directory, - priority=options.priority, - max_restarts=options.max_restarts, - libtpu_type=libtpu_type, - libtpu_nightly_version=options.libtpu_version, - base_docker_image=options.base_docker_image, - xpk_path=options.xpk_path, - pathways_config=pw_config + model=model, + num_slices=options.num_slices, + num_steps=options.num_steps, + device_type=options.device_type, + base_output_directory=options.base_output_directory, + priority=options.priority, + max_restarts=options.max_restarts, + libtpu_type=libtpu_type, + libtpu_nightly_version=options.libtpu_version, + base_docker_image=options.base_docker_image, + xpk_path=options.xpk_path, + pathways_config=pw_config, ) xpk_benchmark_runner(cluster_config, [workload_config]) elif options.runner == "on-device": # Generate a run_name if it is not passed from CLI or M_RUN_NAME env variable is empty - curr_date = 
time.strftime('%Y%m%d') + curr_date = time.strftime("%Y%m%d") if options.run_name is None: try: - run_name = os.environ['M_RUN_NAME'] + run_name = os.environ["M_RUN_NAME"] if run_name == "": - options.run_name = f'{options.model_name}-{curr_date}' + options.run_name = f"{options.model_name}-{curr_date}" except KeyError: - options.run_name = f'{options.model_name}-{curr_date}' + options.run_name = f"{options.model_name}-{curr_date}" workload_config = WorkloadConfig( - model=model, - num_slices=None, - device_type=None, - base_docker_image=None, - num_steps=options.num_steps, - base_output_directory=options.base_output_directory, - libtpu_type=libtpu_type, - libtpu_nightly_version=options.libtpu_version, - run_name=options.run_name + model=model, + num_slices=None, + device_type=None, + base_docker_image=None, + num_steps=options.num_steps, + base_output_directory=options.base_output_directory, + libtpu_type=libtpu_type, + libtpu_nightly_version=options.libtpu_version, + run_name=options.run_name, ) on_device_benchmark_runner(workload_configs=[workload_config]) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/benchmarks/llama2_v6e-256_benchmarks.py b/benchmarks/llama2_v6e-256_benchmarks.py index abbe891b6..a824904d4 100644 --- a/benchmarks/llama2_v6e-256_benchmarks.py +++ b/benchmarks/llama2_v6e-256_benchmarks.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import maxtext_trillium_model_configs as model_configs from maxtext_xpk_runner import BenchmarkRunner from maxtext_xpk_runner import HWConfig @@ -21,19 +22,17 @@ from maxtext_xpk_runner import XpkConfig -DATE = '20241009' -BASE_DOCKER_IMAGE = 'maxtext_base_image' +DATE = "20241009" +BASE_DOCKER_IMAGE = "maxtext_base_image" -ZONE = 'europe-west4' -PROJECT = 'tpu-prod-env-multipod' -CLUSTER_NAME = 'mlperf-v6e-256' -DEVICE_TYPE = 'v6e-256' +ZONE = "europe-west4" +PROJECT = "tpu-prod-env-multipod" +CLUSTER_NAME = "mlperf-v6e-256" +DEVICE_TYPE = "v6e-256" NUM_SLICES = 1 -BASE_OUTPUT_DIR = 'gs://maxtext-experiments-tpem/' +BASE_OUTPUT_DIR = "gs://maxtext-experiments-tpem/" -v6e_env_configs = SWconfig( - base_docker_image=BASE_DOCKER_IMAGE, libtpu_version=DATE -) +v6e_env_configs = SWconfig(base_docker_image=BASE_DOCKER_IMAGE, libtpu_version=DATE) v6e_256_configs = HWConfig(num_slices=NUM_SLICES, device_type=DEVICE_TYPE) llama2_70b_4096 = BenchmarkRunner( @@ -62,5 +61,5 @@ def main() -> None: xpk_benchmark_runner(cluster_config, [llama2_7b_4096, llama2_70b_4096]) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 9b4e7ba87..7bb7feb69 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + """Shared Benchmark config for v6e orchestrations.""" import dataclasses @@ -32,1043 +33,994 @@ class MaxTextModel: tuning_params: dict[str, typing.Any] xla_flags: str + trillium_model_dict = {} + # Run this for new definitions that should be part of the library. 
-def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_model: MaxTextModel)-> MaxTextModel: - model_dictionary[maxtext_model.model_name.replace('-', '_')] = maxtext_model +def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_model: MaxTextModel) -> MaxTextModel: + model_dictionary[maxtext_model.model_name.replace("-", "_")] = maxtext_model return maxtext_model + default_basic_1 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-basic-1", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "remat_policy": "full", - "global_parameter_scale": 1, - "attention": "flash", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags="", - ) + trillium_model_dict, + MaxTextModel( + model_name="default-basic-1", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "remat_policy": "full", + "global_parameter_scale": 1, + "attention": "flash", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags="", + ), ) default_basic_1_pw = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-basic-1-pw", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "remat_policy": "full", - "global_parameter_scale": 1, - "attention": "flash", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - # "profiler": "xplane", - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - # "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags="", - ) + trillium_model_dict, + MaxTextModel( + model_name="default-basic-1-pw", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "remat_policy": "full", + "global_parameter_scale": 1, + "attention": "flash", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + # "profiler": "xplane", + # Additional tuning params for pathways long running test. 
+ "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + # "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags="", + ), ) default_32 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-32", - model_type="default", - tuning_params={ - "per_device_batch_size": 13, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 32, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 1024, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-32", + model_type="default", + tuning_params={ + "per_device_batch_size": 13, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 32, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 1024, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) default_64 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-64", - model_type="default", - tuning_params={ - "per_device_batch_size": 6, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 64, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-64", + model_type="default", + tuning_params={ + "per_device_batch_size": 6, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 64, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) default_128 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-128", - model_type="default", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 128, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - 
"sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-128", + model_type="default", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 128, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) # OOM, Not Optimized yet default_256 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-256", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "ici_fsdp_parallelism": -1, - "dcn_fsdp_transpose_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 256, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-256", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "ici_fsdp_parallelism": -1, + "dcn_fsdp_transpose_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 256, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) # OOM, Not Optimized yet default_512 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-512", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "ici_fsdp_parallelism": -1, - # "dcn_fsdp_parallelism": 2, - "dcn_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 512, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-512", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "ici_fsdp_parallelism": -1, + # "dcn_fsdp_parallelism": 2, + "dcn_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 512, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + 
"sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) gpt_3_175b = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="gpt-3-175b", - model_type="gpt3-175b", - tuning_params={ - "per_device_batch_size": 3, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "attention": "flash", - "quantization": "int8", - "gcs_metrics": True, - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.DISABLE_BUNDLE_AWARE_COST_MODEL + trillium_model_dict, + MaxTextModel( + model_name="gpt-3-175b", + model_type="gpt3-175b", + tuning_params={ + "per_device_batch_size": 3, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "attention": "flash", + "quantization": "int8", + "gcs_metrics": True, + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.DISABLE_BUNDLE_AWARE_COST_MODEL + ), ), - ) ) llama2_7b_4096 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-7b-4096", - model_type="llama2-7b", - tuning_params={ - "per_device_batch_size": 12, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-7b-4096", + model_type="llama2-7b", + tuning_params={ + "per_device_batch_size": 12, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_7b_4096_pw = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-7b-4096-pw", - model_type="llama2-7b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. 
- "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - # "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-7b-4096-pw", + model_type="llama2-7b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + # "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_70b_4096 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_70b_4096_optimized = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2_70b_4096_synthetic", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - 
"enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2_70b_4096_synthetic", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) # Enable SparseCore Offloading of AR in an optimized model. llama2_70b_4096_sc = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-sc", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-sc", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + ), ), - ) ) llama2_70b_4096_sc_real_data_tfds = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-sc", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-sc", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + 
"ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + ), ), - ) ) llama2_70b_4096_sc_real_data_grain = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "base_output_directory": ( - "gs://trillium-storage-tests-nov24-sr/long-run-dec11" + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "base_output_directory": ("gs://trillium-storage-tests-nov24-sr/long-run-dec11"), + "enable_checkpointing": False, + "dataset_type": "grain", + "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*", + "grain_worker_count": 24, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "profile_cleanly": False, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE ), - "enable_checkpointing": False, - "dataset_type": "grain", - "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*", - "grain_worker_count": 24, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "profile_cleanly": False, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE ), - ) ) llama2_70b_4096_sc_real_data_grain_checkpoint = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "base_output_directory": ( - "gs://trillium-storage-tests-nov24-sr/long-run-dec11" + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": 
"qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "base_output_directory": ("gs://trillium-storage-tests-nov24-sr/long-run-dec11"), + "checkpoint_period": 100, + "enable_checkpointing": True, + "async_checkpointing": True, + "dataset_type": "grain", + "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*", + "grain_worker_count": 24, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE ), - "checkpoint_period": 100, - "enable_checkpointing": True, - "async_checkpointing": True, - "dataset_type": "grain", - "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*", - "grain_worker_count": 24, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE ), - ) ) llama2_70b_4096_real_data_pw_long_run = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-rd-pw-lr", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "reuse_example_batch": 0, - "profiler": "xplane", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "tfds", - "tokenizer_path": "assets/tokenizer.llama2", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-rd-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "reuse_example_batch": 0, + "profiler": "xplane", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "tfds", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + # Additional tuning params for pathways long running test. 
+ "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_70b_4096_synthetic_pw_lr = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2_70b_4096_synthetic_pw_lr", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - # "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2_70b_4096_synthetic_pw_lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + # "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + # Additional tuning params for pathways long running test. 
+ "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_70b_4096_pw_long_run = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-pw-lr", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-pw-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "steps": 1000000, + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_70b_4096_pw_rd_tfds = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2_70b_4096_pw_rd_tfds", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - - # Additional tuning params for pathways long running test. 
- "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2_70b_4096_pw_rd_tfds", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + # Additional tuning params for pathways long running test. + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "checkpoint_storage_use_ocdbt": False, + "checkpoint_storage_use_zarr3": False, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_pathways_goodput": True, + "enable_checkpoint_cloud_logger": True, + "enable_single_controller": True, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama3_8b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3-8b-8192", - model_type="llama3-8b", - tuning_params={ - "per_device_batch_size": 8, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama3-8b-8192", + model_type="llama3-8b", + tuning_params={ + "per_device_batch_size": 8, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=(xla_flags_library.DENSE_VMEM_LIMIT_FLAG + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama3_70b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3-70b-8192", - model_type="llama3-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "optimizer_memory_host_offload": True, - "gradient_clipping_threshold": 0, - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - 
 llama3_70b_8192 = _add_to_model_dictionary(
-  trillium_model_dict,
-  MaxTextModel(
-    model_name="llama3-70b-8192",
-    model_type="llama3-70b",
-    tuning_params={
-      "per_device_batch_size": 2,
-      "ici_fsdp_parallelism": -1,
-      "remat_policy": "full",
-      "optimizer_memory_host_offload": True,
-      "gradient_clipping_threshold": 0,
-      "max_target_length": 8192,
-      "attention": "flash",
-      "gcs_metrics": True,
-      "use_iota_embed": True,
-      "dataset_path": "gs://max-datasets-rogue",
-      "dataset_type": "synthetic",
-      "reuse_example_batch": 1,
-      "enable_checkpointing": False,
-      "profiler": "xplane",
-      "sa_block_q": 1024,
-      "sa_block_q_dkv": 2048,
-      "sa_block_q_dq": 2048,
-    },
-    xla_flags=(
-      xla_flags_library.DENSE_VMEM_LIMIT_FLAG
-      + xla_flags_library.CF_FOR_ALL_GATHER
-      + xla_flags_library.HOST_OFFLOAD_FLAGS
-      + " --xla_tpu_scheduler_percent_shared_memory_limit=90"
-    ),
-  )
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="llama3-70b-8192",
+        model_type="llama3-70b",
+        tuning_params={
+            "per_device_batch_size": 2,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "full",
+            "optimizer_memory_host_offload": True,
+            "gradient_clipping_threshold": 0,
+            "max_target_length": 8192,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "sa_block_q": 1024,
+            "sa_block_q_dkv": 2048,
+            "sa_block_q_dq": 2048,
+        },
+        xla_flags=(
+            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+            + xla_flags_library.CF_FOR_ALL_GATHER
+            + xla_flags_library.HOST_OFFLOAD_FLAGS
+            + " --xla_tpu_scheduler_percent_shared_memory_limit=90"
+        ),
+    ),
 )

 llama3_1_405b_8192_fsdp_dcn = _add_to_model_dictionary(
-  trillium_model_dict,
-  MaxTextModel(
-    model_name="llama3-1-405b-8192-fsdp-dcn",
-    model_type="llama3.1-405b",
-    tuning_params={
-      "per_device_batch_size": 1,
-      "ici_fsdp_parallelism": 64,
-      "ici_tensor_parallelism": 4,
-      "dcn_fsdp_parallelism": 2,
-      "allow_split_physical_axes": True,
-      "custom_mesh": "hybrid_ring_64x4",
-      "remat_policy": "custom",
-      "decoder_layer_input": "offload",
-      "query_proj": "offload",
-      "key_proj": "offload",
-      "value_proj": "offload",
-      "out_proj": "offload",
-      "max_target_length": 8192,
-      "attention": "flash",
-      "gcs_metrics": True,
-      "use_iota_embed": True,
-      "dataset_path": "gs://max-datasets-rogue",
-      "dataset_type": "synthetic",
-      "reuse_example_batch": 1,
-      "enable_checkpointing": False,
-      "profiler": "xplane",
-      "sa_block_q": 1024,
-      "sa_block_q_dkv": 2048,
-      "sa_block_q_dq": 2048,
-    },
-    xla_flags=(
-      xla_flags_library.DENSE_VMEM_LIMIT_FLAG
-      + xla_flags_library.CF_FOR_ALL_GATHER
-      + xla_flags_library.HOST_OFFLOAD_FLAGS
-    ),
-  )
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="llama3-1-405b-8192-fsdp-dcn",
+        model_type="llama3.1-405b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "ici_fsdp_parallelism": 64,
+            "ici_tensor_parallelism": 4,
+            "dcn_fsdp_parallelism": 2,
+            "allow_split_physical_axes": True,
+            "custom_mesh": "hybrid_ring_64x4",
+            "remat_policy": "custom",
+            "decoder_layer_input": "offload",
+            "query_proj": "offload",
+            "key_proj": "offload",
+            "value_proj": "offload",
+            "out_proj": "offload",
+            "max_target_length": 8192,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "sa_block_q": 1024,
+            "sa_block_q_dkv": 2048,
+            "sa_block_q_dq": 2048,
+        },
+        xla_flags=(
+            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+            + xla_flags_library.CF_FOR_ALL_GATHER
+            + xla_flags_library.HOST_OFFLOAD_FLAGS
+        ),
+    ),
 )
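The 405B entry shards each slice as a 64-way FSDP by 4-way tensor-parallel mesh, with 2-way FSDP across slices over DCN. A quick sanity check of the chip arithmetic (assuming 256-chip slices, as the `hybrid_ring_64x4` mesh name suggests):

```python
ici_fsdp_parallelism, ici_tensor_parallelism = 64, 4
dcn_fsdp_parallelism = 2

chips_per_slice = ici_fsdp_parallelism * ici_tensor_parallelism  # 256
total_chips = chips_per_slice * dcn_fsdp_parallelism             # 512 over two slices
print(chips_per_slice, total_chips)
```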
"synthetic", - "enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-8b-8192", + model_type="llama3.1-8b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), ), - ) ) llama3_1_70b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3_1-70b-8192", - model_type="llama3.1-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "max_target_length": 8192, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-70b-8192", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + 
"sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), ), - ) ) llama3_1_70b_129024 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3_1-70b-129024", - model_type="llama3.1-70b", - tuning_params={ - "per_device_batch_size": 0.125, - "ici_fsdp_parallelism": -1, - "ici_sequence_parallelism": 8, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "max_target_length": 129024, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - "allow_split_physical_axes": True, - "custom_mesh": "hybrid_ring_32x8", - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-70b-129024", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.125, + "ici_fsdp_parallelism": -1, + "ici_sequence_parallelism": 8, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 129024, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + "allow_split_physical_axes": True, + "custom_mesh": "hybrid_ring_32x8", + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), ), - ) ) mixtral_8x7b_dropless = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x7b_dropless", - model_type="mixtral-8x7b", - tuning_params={ - "per_device_batch_size": 12, - "ici_fsdp_parallelism": -1, - "max_target_length": 4096, - "remat_policy": "full", - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - 
"dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": True, - "sparse_matmul": True, - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x7b_dropless", + model_type="mixtral-8x7b", + tuning_params={ + "per_device_batch_size": 12, + "ici_fsdp_parallelism": -1, + "max_target_length": 4096, + "remat_policy": "full", + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "megablox": True, + "sparse_matmul": True, + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) mixtral_8x7b_dropped = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x7b_dropped", - model_type="mixtral-8x7b", - tuning_params={ - "per_device_batch_size": 12, - "ici_fsdp_parallelism": -1, - "max_target_length": 4096, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": False, - "sparse_matmul": False, - "capacity_factor": 1.25, - "tokenizer_path": "assets/tokenizer.mistral-v1", - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x7b_dropped", + model_type="mixtral-8x7b", + tuning_params={ + "per_device_batch_size": 12, + "ici_fsdp_parallelism": -1, + "max_target_length": 4096, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "megablox": False, + "sparse_matmul": False, + "capacity_factor": 1.25, + "tokenizer_path": "assets/tokenizer.mistral-v1", + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) mixtral_8x7b_dropped_int8 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x7b_dropped_int8", - model_type="mixtral-8x7b", - tuning_params={ - "per_device_batch_size": 8, - "ici_fsdp_parallelism": -1, - "max_target_length": 4096, - "remat_policy": "full", - "attention": "flash", - "gcs_metrics": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - 
"enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": False, - "sparse_matmul": False, - "capacity_factor": 1.25, - "quantization": "int8", - "tokenizer_path": "assets/tokenizer.mistral-v1", - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x7b_dropped_int8", + model_type="mixtral-8x7b", + tuning_params={ + "per_device_batch_size": 8, + "ici_fsdp_parallelism": -1, + "max_target_length": 4096, + "remat_policy": "full", + "attention": "flash", + "gcs_metrics": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "megablox": False, + "sparse_matmul": False, + "capacity_factor": 1.25, + "quantization": "int8", + "tokenizer_path": "assets/tokenizer.mistral-v1", + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) mixtral_8x22b_dropped = _add_to_model_dictionary( @@ -1116,66 +1068,66 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_ ) gemma2_9b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="gemma2-9b-8192", - model_type="gemma2-9b", - tuning_params={ - "per_device_batch_size": 3, - "ici_fsdp_transpose_parallelism": 256, - "remat_policy": "full", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "tokenizer_path": "assets/tokenizer.llama2", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(114688) - + xla_flags_library.REDUCE_SCATTER_FUSION - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + trillium_model_dict, + MaxTextModel( + model_name="gemma2-9b-8192", + model_type="gemma2-9b", + tuning_params={ + "per_device_batch_size": 3, + "ici_fsdp_transpose_parallelism": 256, + "remat_policy": "full", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(114688) + + xla_flags_library.REDUCE_SCATTER_FUSION + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + ), ), - ) ) gemma2_27b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="gemma2-27b-8192", - model_type="gemma2-27b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_transpose_parallelism": 256, - "remat_policy": "full", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - 
"enable_checkpointing": False, - "profiler": "xplane", - "tokenizer_path": "assets/tokenizer.llama2", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(122880) - + xla_flags_library.REDUCE_SCATTER_FUSION - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + trillium_model_dict, + MaxTextModel( + model_name="gemma2-27b-8192", + model_type="gemma2-27b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_transpose_parallelism": 256, + "remat_policy": "full", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(122880) + + xla_flags_library.REDUCE_SCATTER_FUSION + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + ), ), - ) ) diff --git a/benchmarks/maxtext_v5e_model_configs.py b/benchmarks/maxtext_v5e_model_configs.py index f9e2f5fb2..d7dbec19c 100644 --- a/benchmarks/maxtext_v5e_model_configs.py +++ b/benchmarks/maxtext_v5e_model_configs.py @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + """Shared Benchmark config for v5e orchestrations.""" import xla_flags_library @@ -19,205 +20,181 @@ v5e_model_dict = {} default_16b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="default-16b-v5e-256", - model_type="default", - tuning_params={ - "per_device_batch_size": 6, - "remat_policy": "full", - "global_parameter_scale": 16, - "max_target_length": 2048, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="default-16b-v5e-256", + model_type="default", + tuning_params={ + "per_device_batch_size": 6, + "remat_policy": "full", + "global_parameter_scale": 16, + "max_target_length": 2048, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) default_32b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="default-32b-v5e-256", - model_type="default", - tuning_params={ - "per_device_batch_size": 4, - "remat_policy": "full", - "global_parameter_scale": 32, - "max_target_length": 2048, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="default-32b-v5e-256", + model_type="default", + tuning_params={ + "per_device_batch_size": 4, + "remat_policy": "full", 
+ "global_parameter_scale": 32, + "max_target_length": 2048, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) default_64b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="default-64b-v5e-256", - model_type="default", - tuning_params={ - "per_device_batch_size": 2, - "remat_policy": "full", - "global_parameter_scale": 64, - "max_target_length": 2048, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="default-64b-v5e-256", + model_type="default", + tuning_params={ + "per_device_batch_size": 2, + "remat_policy": "full", + "global_parameter_scale": 64, + "max_target_length": 2048, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) default_128b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="default-128b-v5e-256", - model_type="default", - tuning_params={ - "ici_fsdp_parallelism": -1, - "ici_tensor_parallelism": 16, - "per_device_batch_size": 1, - "remat_policy": "qkv_proj_offloaded", - "global_parameter_scale": 128, - "max_target_length": 2048, - "attention": "flash", - "use_iota_embed": True, - "fused_qkv": True, - "fused_mlp": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="default-128b-v5e-256", + model_type="default", + tuning_params={ + "ici_fsdp_parallelism": -1, + "ici_tensor_parallelism": 16, + "per_device_batch_size": 1, + "remat_policy": "qkv_proj_offloaded", + "global_parameter_scale": 128, + "max_target_length": 2048, + "attention": "flash", + "use_iota_embed": True, + "fused_qkv": True, + "fused_mlp": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) gpt_3_175b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="gpt-3-175b-v5e-256", - model_type="gpt3-175b", - tuning_params={ - "ici_fsdp_parallelism": -1, - "ici_tensor_parallelism": 16, - "per_device_batch_size": 0.5, - "remat_policy": "full", - "max_target_length": 2048, - "attention": "flash", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="gpt-3-175b-v5e-256", + model_type="gpt3-175b", + tuning_params={ 
+ "ici_fsdp_parallelism": -1, + "ici_tensor_parallelism": 16, + "per_device_batch_size": 0.5, + "remat_policy": "full", + "max_target_length": 2048, + "attention": "flash", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_7b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="llama2-7b-v5e-256", - model_type="llama2-7b", - tuning_params={ - "ici_fsdp_parallelism": -1, - "per_device_batch_size": 4, - "remat_policy": "save_qkv_proj", - "max_target_length": 2048, - "use_iota_embed": True, - "tokenizer_path": "assets/tokenizer.llama2", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="llama2-7b-v5e-256", + model_type="llama2-7b", + tuning_params={ + "ici_fsdp_parallelism": -1, + "per_device_batch_size": 4, + "remat_policy": "save_qkv_proj", + "max_target_length": 2048, + "use_iota_embed": True, + "tokenizer_path": "assets/tokenizer.llama2", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_13b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="llama2-13b-v5e-256", - model_type="llama2-13b", - tuning_params={ - "ici_fsdp_parallelism": -1, - "per_device_batch_size": 8, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 2048, - "use_iota_embed": True, - "tokenizer_path": "assets/tokenizer.llama2", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="llama2-13b-v5e-256", + model_type="llama2-13b", + tuning_params={ + "ici_fsdp_parallelism": -1, + "per_device_batch_size": 8, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 2048, + "use_iota_embed": True, + "tokenizer_path": "assets/tokenizer.llama2", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) llama2_70b_v5e_256 = _add_to_model_dictionary( - v5e_model_dict, - MaxTextModel( - model_name="llama2-70b-v5e-256", - model_type="llama2-70b", - tuning_params={ - "ici_fsdp_parallelism": -1, - "per_device_batch_size": 2, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 2048, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags=( - xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER + v5e_model_dict, + MaxTextModel( + model_name="llama2-70b-v5e-256", + model_type="llama2-70b", + tuning_params={ + "ici_fsdp_parallelism": -1, + "per_device_batch_size": 2, + 
"remat_policy": "qkv_proj_offloaded", + "max_target_length": 2048, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + xla_flags=(xla_flags_library.DATA_PARALLEL_OVERLAP + xla_flags_library.CF_FOR_ALL_GATHER), ), - ) ) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index c22de9157..af94d4fc8 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -38,14 +38,15 @@ # Assumes you built maxtext dep image. # Assumes you have xpk installed in a git clone repo of ~/{wl_config.xpk_path}/xpk.py -_DEFAULT_MAXTEXT_BASE_DOCKER_IMAGE_NAME = 'maxtext_base_image' +_DEFAULT_MAXTEXT_BASE_DOCKER_IMAGE_NAME = "maxtext_base_image" + class LibTpuType(enum.Enum): - NIGHTLY = 'nightly-libtpu' + NIGHTLY = "nightly-libtpu" # In order to use a custom libtpu, put a libtpu.so file in your local # working directory. - CUSTOM = 'custom' - MAXTEXT = 'maxtext-docker' + CUSTOM = "custom" + MAXTEXT = "maxtext-docker" @dataclasses.dataclass @@ -70,11 +71,11 @@ class WorkloadConfig: base_output_directory: str base_docker_image: str libtpu_type: LibTpuType - libtpu_nightly_version: str = None # A date in %Y%M%D format, 20241201 + libtpu_nightly_version: str = None # A date in %Y%M%D format, 20241201 num_steps: int = 20 max_restarts: int = 0 priority: str = "medium" - xpk_path: str = '~/xpk' + xpk_path: str = "~/xpk" pathways_config: PathwaysConfig = None run_name: str = None @@ -112,12 +113,7 @@ def make_tmp_files(per_command_name): A list of temporary files for each command. """ # Supports removal of spaces from command names before converting to file name. - return [ - tempfile.NamedTemporaryFile( - delete=False, prefix=command.replace(' ', '-') + '-' - ) - for command in per_command_name - ] + return [tempfile.NamedTemporaryFile(delete=False, prefix=command.replace(" ", "-") + "-") for command in per_command_name] def run_commands(commands, jobname, per_command_name, batch=10, dry_run=False): @@ -137,17 +133,14 @@ def run_commands(commands, jobname, per_command_name, batch=10, dry_run=False): commands_batched = chunks(commands, batch) per_command_name_batches = chunks(per_command_name, batch) - print( - f'Breaking up a total of {len(commands)} commands into' - f' {len(commands_batched)} batches' - ) + print(f"Breaking up a total of {len(commands)} commands into" f" {len(commands_batched)} batches") if dry_run: - print('Pretending all the jobs succeeded') + print("Pretending all the jobs succeeded") return 0 max_return_code = 0 for i, _ in enumerate(commands_batched): - print(f'Dispatching batch {i}/{len(commands_batched)}') + print(f"Dispatching batch {i}/{len(commands_batched)}") batch_max_return_code, _ = run_command_batch( commands_batched[i], jobname, @@ -178,9 +171,7 @@ def run_command_batch(commands, jobname, per_command_name, output_logs): for i, command in enumerate(commands): children.append( # subprocess managed by list pylint: disable=consider-using-with - subprocess.Popen( - command, stdout=output_logs[i], stderr=output_logs[i], shell=True - ) + subprocess.Popen(command, stdout=output_logs[i], stderr=output_logs[i], shell=True) ) while True: @@ -192,25 +183,14 @@ def run_command_batch(commands, jobname, per_command_name, output_logs): if completed < total: slow_worker_index = returncodes.index(None) slow_worker_text = per_command_name[slow_worker_index] - slow_str = ( - f', task {slow_worker_text} still 
diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py
index c22de9157..af94d4fc8 100644
--- a/benchmarks/maxtext_xpk_runner.py
+++ b/benchmarks/maxtext_xpk_runner.py
@@ -38,14 +38,15 @@
 # Assumes you built maxtext dep image.
 # Assumes you have xpk installed in a git clone repo of ~/{wl_config.xpk_path}/xpk.py
-_DEFAULT_MAXTEXT_BASE_DOCKER_IMAGE_NAME = 'maxtext_base_image'
+_DEFAULT_MAXTEXT_BASE_DOCKER_IMAGE_NAME = "maxtext_base_image"
+
 
 class LibTpuType(enum.Enum):
-  NIGHTLY = 'nightly-libtpu'
+  NIGHTLY = "nightly-libtpu"
   # In order to use a custom libtpu, put a libtpu.so file in your local
   # working directory.
-  CUSTOM = 'custom'
-  MAXTEXT = 'maxtext-docker'
+  CUSTOM = "custom"
+  MAXTEXT = "maxtext-docker"
 
 
 @dataclasses.dataclass
@@ -70,11 +71,11 @@ class WorkloadConfig:
   base_output_directory: str
   base_docker_image: str
   libtpu_type: LibTpuType
-  libtpu_nightly_version: str = None # A date in %Y%M%D format, 20241201
+  libtpu_nightly_version: str = None  # A date in %Y%m%d format, e.g. 20241201
   num_steps: int = 20
   max_restarts: int = 0
   priority: str = "medium"
-  xpk_path: str = '~/xpk'
+  xpk_path: str = "~/xpk"
   pathways_config: PathwaysConfig = None
   run_name: str = None
 
@@ -112,12 +113,7 @@ def make_tmp_files(per_command_name):
     A list of temporary files for each command.
   """
   # Supports removal of spaces from command names before converting to file name.
-  return [
-      tempfile.NamedTemporaryFile(
-          delete=False, prefix=command.replace(' ', '-') + '-'
-      )
-      for command in per_command_name
-  ]
+  return [tempfile.NamedTemporaryFile(delete=False, prefix=command.replace(" ", "-") + "-") for command in per_command_name]
 
 
 def run_commands(commands, jobname, per_command_name, batch=10, dry_run=False):
@@ -137,17 +133,14 @@
   commands_batched = chunks(commands, batch)
   per_command_name_batches = chunks(per_command_name, batch)
 
-  print(
-      f'Breaking up a total of {len(commands)} commands into'
-      f' {len(commands_batched)} batches'
-  )
+  print(f"Breaking up a total of {len(commands)} commands into" f" {len(commands_batched)} batches")
   if dry_run:
-    print('Pretending all the jobs succeeded')
+    print("Pretending all the jobs succeeded")
     return 0
 
   max_return_code = 0
   for i, _ in enumerate(commands_batched):
-    print(f'Dispatching batch {i}/{len(commands_batched)}')
+    print(f"Dispatching batch {i}/{len(commands_batched)}")
     batch_max_return_code, _ = run_command_batch(
         commands_batched[i],
         jobname,
@@ -178,9 +171,7 @@
   for i, command in enumerate(commands):
     children.append(
        # subprocess managed by list pylint: disable=consider-using-with
-        subprocess.Popen(
-            command, stdout=output_logs[i], stderr=output_logs[i], shell=True
-        )
+        subprocess.Popen(command, stdout=output_logs[i], stderr=output_logs[i], shell=True)
    )
 
   while True:
@@ -192,25 +183,14 @@
     if completed < total:
       slow_worker_index = returncodes.index(None)
       slow_worker_text = per_command_name[slow_worker_index]
-      slow_str = (
-          f', task {slow_worker_text} still working, logfile'
-          f' {output_logs[slow_worker_index].name}'
-      )
+      slow_str = f", task {slow_worker_text} still working, logfile" f" {output_logs[slow_worker_index].name}"
     else:
-      slow_str = ''
-    print(
-        f'[t={seconds_elapsed:.2f}, {jobname}] Completed'
-        f' {completed}/{total}{slow_str}'
-    )
+      slow_str = ""
+    print(f"[t={seconds_elapsed:.2f}, {jobname}] Completed" f" {completed}/{total}{slow_str}")
     if max_returncode > 0:
-      failing_index = [
-          i for i, x in enumerate(returncodes) if x is not None and x > 0
-      ][0]
-      print(f'Terminating all {jobname} processes since at least one failed.')
-      print(
-          f'Failure is {per_command_name[failing_index]}'
-          f' and logfile {output_logs[failing_index].name}'
-      )
+      failing_index = [i for i, x in enumerate(returncodes) if x is not None and x > 0][0]
+      print(f"Terminating all {jobname} processes since at least one failed.")
+      print(f"Failure is {per_command_name[failing_index]}" f" and logfile {output_logs[failing_index].name}")
       for child in children:
         child.terminate()
       break
@@ -236,9 +216,7 @@ def run_command_with_updates(command, task, verbose=True) -> int:
   """
   if verbose:
-    print(
-        f'Task: `{task}` is implemented by `{command}`, streaming output live.'
-    )
+    print(f"Task: `{task}` is implemented by `{command}`, streaming output live.")
     with subprocess.Popen(
         command,
         stdout=sys.stdout,
@@ -249,29 +227,23 @@
       while True:
         return_code = child.poll()
         if return_code is None:
-          print(f'Waiting for `{task}`, for {i} seconds')
+          print(f"Waiting for `{task}`, for {i} seconds")
           time.sleep(1)
           i += 1
         else:
-          print(f'Task: `{task}` terminated with code `{return_code}`')
+          print(f"Task: `{task}` terminated with code `{return_code}`")
           return return_code
   else:
-    print(
-        f'Task: `{task}` is implemented by `{command}`, hiding output unless'
-        ' there is an error.'
-    )
+    print(f"Task: `{task}` is implemented by `{command}`, hiding output unless" " there is an error.")
     try:
       subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
     except subprocess.CalledProcessError as e:
-      print(
-          f'Task: `{task}` terminated with ERROR `{e.returncode}`, printing'
-          ' logs'
-      )
-      print('*' * 80)
+      print(f"Task: `{task}` terminated with ERROR `{e.returncode}`, printing" " logs")
+      print("*" * 80)
       print(e.output)
-      print('*' * 80)
+      print("*" * 80)
       return e.returncode
-    print(f'Task: `{task}` succeeded.')
+    print(f"Task: `{task}` succeeded.")
     return 0
 
@@ -281,58 +253,60 @@ def build_user_command(
 ):
   is_pw_enabled = wl_config.pathways_config is not None
 
-  config_tuning_params = ''
+  config_tuning_params = ""
   for key, value in wl_config.model.tuning_params.items():
-    config_tuning_params += f'{key}={value} '
+    config_tuning_params += f"{key}={value} "
 
-  install_libtpu_cmd = ''
+  install_libtpu_cmd = ""
   jax_platforms = None
   vertex_tensorboard = ''
   # TODO() support modifying nightly / stable dependencies in pathway flow
   if is_pw_enabled:
-    jax_platforms = 'proxy'
+    jax_platforms = "proxy"
   else:
     if wl_config.libtpu_type == LibTpuType.NIGHTLY:
       install_libtpu_cmd += (
-          f' pip install libtpu-nightly==0.1.dev{wl_config.libtpu_nightly_version} -f'
-          ' https://storage.googleapis.com/libtpu-releases/index.html &&'
+          f" pip install libtpu-nightly==0.1.dev{wl_config.libtpu_nightly_version} -f"
+          " https://storage.googleapis.com/libtpu-releases/index.html &&"
      )
     elif wl_config.libtpu_type == LibTpuType.CUSTOM:
       # In order to use a custom libtpu, put a libtpu.so file in your local
       # working directory.
-      install_libtpu_cmd += ' mv libtpu.so /lib/ &&'
+      install_libtpu_cmd += " mv libtpu.so /lib/ &&"
     elif wl_config.libtpu_type == LibTpuType.MAXTEXT:
       # Use the libtpu dependent built in the docker image provided.
-      install_libtpu_cmd += ''
+      install_libtpu_cmd += ""
 
-    jax_platforms = 'tpu,cpu'
+    jax_platforms = "tpu,cpu"
     vertex_tensorboard = 'use_vertex_tensorboard=false vertex_tensorboard_project="" vertex_tensorboard_region=""'
 
-  assert jax_platforms is not None, 'Error in setting jax_platforms'
+  assert jax_platforms is not None, "Error in setting jax_platforms"
 
-  libtpu_flags = f'LIBTPU_INIT_ARGS=\'{wl_config.model.xla_flags}\''
+  libtpu_flags = f"LIBTPU_INIT_ARGS='{wl_config.model.xla_flags}'"
 
   if name is None:
-    run_name_command=""
+    run_name_command = ""
   else:
-    run_name_command=f'run_name={name}'
+    run_name_command = f"run_name={name}"
 
   # Construct the command string with proper formatting and line continuations
-  command = ' '.join([
-      f'{install_libtpu_cmd}',
-      f'echo {libtpu_flags} &&' if not is_pw_enabled else '',
-      f'export {libtpu_flags} &&' if not is_pw_enabled else '',
-      'export ENABLE_PATHWAYS_PERSISTENCE=1 &&',
-      f'export JAX_PLATFORMS={jax_platforms} &&',
-      'export ENABLE_PJRT_COMPATIBILITY=true &&',
-      'python3 MaxText/train.py MaxText/configs/base.yml',
-      f'{config_tuning_params}',
-      f'steps={wl_config.num_steps}',
-      f'model_name={wl_config.model.model_type}',
-      f'base_output_directory={wl_config.base_output_directory}',
-      f'{vertex_tensorboard}',
-      f'{run_name_command}'
-  ])
+  command = " ".join(
+      [
+          f"{install_libtpu_cmd}",
+          f"echo {libtpu_flags} &&" if not is_pw_enabled else "",
+          f"export {libtpu_flags} &&" if not is_pw_enabled else "",
+          "export ENABLE_PATHWAYS_PERSISTENCE=1 &&",
+          f"export JAX_PLATFORMS={jax_platforms} &&",
+          "export ENABLE_PJRT_COMPATIBILITY=true &&",
+          "python3 MaxText/train.py MaxText/configs/base.yml",
+          f"{config_tuning_params}",
+          f"steps={wl_config.num_steps}",
+          f"model_name={wl_config.model.model_type}",
+          f"base_output_directory={wl_config.base_output_directory}",
+          f"{vertex_tensorboard}",
+          f"{run_name_command}",
+      ]
+  )
   return command
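For a non-Pathways run, the joined command comes out roughly like the following (hypothetical, abbreviated values; element order as in the `join` above):

```python
# Hypothetical rendering of build_user_command's output for a non-Pathways workload.
example_command = (
    "echo LIBTPU_INIT_ARGS='...' && "
    "export LIBTPU_INIT_ARGS='...' && "
    "export ENABLE_PATHWAYS_PERSISTENCE=1 && "
    "export JAX_PLATFORMS=tpu,cpu && "
    "export ENABLE_PJRT_COMPATIBILITY=true && "
    "python3 MaxText/train.py MaxText/configs/base.yml "
    "per_device_batch_size=2 remat_policy=full ... "
    "steps=20 model_name=llama2-70b base_output_directory=gs://... "
    'use_vertex_tensorboard=false vertex_tensorboard_project="" vertex_tensorboard_region="" '
    "run_name=..."
)
```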
docker_image_flag = "" # pathways-related flags - pathways_specific_flags = '' - workload_create_command = f'python3 {wl_config.xpk_path}/xpk.py workload create' - device_type = f' --device-type={cluster_config.device_type}' + pathways_specific_flags = "" if is_pathways_enabled: pw_config = wl_config.pathways_config - device_type = f' --tpu-type={wl_config.device_type}' - workload_create_command = ( - f'python3 {wl_config.xpk_path}/xpk.py workload create-pathways' - ) - docker_image_flag = ( - f'--docker-image={pw_config.runner_image}' + pathways_specific_flags = ( + "--use-pathways" + f" --server-image={pw_config.server_image}" + f" --proxy-server-image={pw_config.proxy_image}" + f" --remote-python-sidecar-image={pw_config.remote_python_sidecar_image}" + if pw_config.remote_python_sidecar_image is not None + else "" + " --termination-grace-period-seconds=300" + f" --pathways-gcs-location={wl_config.base_output_directory}" + f" --restart-on-user-code-failure" + f" --debug-dump-gcs={wl_config.base_output_directory}" ) + docker_image_flag = f"--docker-image={pw_config.runner_image}" else: docker_image_flag = f'--base-docker-image="{wl_config.base_docker_image}"' - print(f'User command: {user_command}') + print(f"User command: {user_command}") return ( ( - f'{workload_create_command}' - f' {_get_pathways_specific_flags(wl_config)}' - f' --cluster={cluster_config.cluster_name}' - f' --project={cluster_config.project}' - f' --zone={cluster_config.zone}' - f' {device_type}' - f' --num-slices={wl_config.num_slices}' + f"python3 {wl_config.xpk_path}/xpk.py workload create" + f" {pathways_specific_flags}" + f" --cluster={cluster_config.cluster_name}" + f" --project={cluster_config.project}" + f" --zone={cluster_config.zone}" + f" --device-type={cluster_config.device_type}" + f" --num-slices={wl_config.num_slices}" f' --command="{user_command}"' - f' {docker_image_flag}' - ' --enable-debug-logs' - f' --workload={name}' - f' --priority={wl_config.priority}' - f' --max-restarts={wl_config.max_restarts}' + f" {docker_image_flag}" + " --enable-debug-logs" + f" --workload={name}" + f" --priority={wl_config.priority}" + f" --max-restarts={wl_config.max_restarts}" # ' --use-vertex-tensorboard' # f' --experiment-name={test_purpose_name}' - f' {additional_flags}' + f" {additional_flags}" ), name, ) @@ -465,12 +434,11 @@ def run_xpk_workload( Returns: """ - assert cluster_config.device_type == wl_config.device_type, f"The workload device size {wl_config.device_type}, and cluster device size {cluster_config.device_type} don't match." - command, _ = generate_xpk_workload_cmd( - cluster_config=cluster_config, - wl_config=wl_config - ) - return run_command_with_updates(command, 'Run XPK workload') + assert ( + cluster_config.device_type == wl_config.device_type + ), f"The workload device size {wl_config.device_type}, and cluster device size {cluster_config.device_type} don't match." 
@@ -465,12 +434,11 @@ def run_xpk_workload(
   Returns:
   """
-  assert cluster_config.device_type == wl_config.device_type, f"The workload device size {wl_config.device_type}, and cluster device size {cluster_config.device_type} don't match."
-  command, _ = generate_xpk_workload_cmd(
-      cluster_config=cluster_config,
-      wl_config=wl_config
-  )
-  return run_command_with_updates(command, 'Run XPK workload')
+  assert (
+      cluster_config.device_type == wl_config.device_type
+  ), f"The workload device type {wl_config.device_type} and cluster device type {cluster_config.device_type} don't match."
+  command, _ = generate_xpk_workload_cmd(cluster_config=cluster_config, wl_config=wl_config)
+  return run_command_with_updates(command, "Run XPK workload")
 
 
 def xpk_benchmark_runner(
@@ -480,10 +448,7 @@
   xpk_workload_names = []
   xpk_workload_cmds = []
   for wl_config in workload_configs:
-    command, name = generate_xpk_workload_cmd(
-        cluster_config=cluster_config,
-        wl_config=wl_config
-    )
+    command, name = generate_xpk_workload_cmd(cluster_config=cluster_config, wl_config=wl_config)
 
     print(f"Name of the workload is: {name} \n")
     xpk_workload_names.append(name)
@@ -495,46 +460,45 @@
   for xpk_workload_name, xpk_workload_cmd in zip(xpk_workload_names, xpk_workload_cmds):
     return_code = run_command_with_updates(xpk_workload_cmd, xpk_workload_name)
     if return_code != 0:
-      print('Unable to run xpk workload: {xpk_workload_name}')
+      print(f"Unable to run xpk workload: {xpk_workload_name}")
+
 
 def on_device_benchmark_runner(
     workload_configs: list[WorkloadConfig],
 ):
   for wl_config in workload_configs:
-    user_command = build_user_command(
-        name=wl_config.run_name,
-        wl_config=wl_config
-    )
-    print(f'User command: {user_command}')
+    user_command = build_user_command(name=wl_config.run_name, wl_config=wl_config)
+    print(f"User command: {user_command}")
     subprocess.run(user_command, shell=True, text=True)
 
+
 # Run maxtext_xpk_runner.py as a script for executing multiple workloads pythonically!
 def main() -> int:
   # Variables to configure:
-  output_bucket = 'gs://DIR'
+  output_bucket = "gs://DIR"
   base_docker_image = _DEFAULT_MAXTEXT_BASE_DOCKER_IMAGE_NAME
 
   # Set up the clusters to run workloads on!
   v5e_cluster_config = XpkClusterConfig(
-      cluster_name='v5e-256',
-      project='my-cool-project',
-      zone='us-central2-b',
-      device_type='v5litepod-256',
+      cluster_name="v5e-256",
+      project="my-cool-project",
+      zone="us-central2-b",
+      device_type="v5litepod-256",
   )
 
   v6e_cluster_config = XpkClusterConfig(
-      cluster_name='v6e-256',
-      project='my-cool-project',
-      zone='us-central2-b',
-      device_type='v6e-256',
+      cluster_name="v6e-256",
+      project="my-cool-project",
+      zone="us-central2-b",
+      device_type="v6e-256",
   )
 
   xpk_workload_cmds = []
   xpk_workload_names = []
   list_of_models = [
-    model_configs.llama2_70b_4096_sc,
-    # model_configs.default_128
+      model_configs.llama2_70b_4096_sc,
+      # model_configs.default_128
   ]
 
   # Loop possibilities:
@@ -551,19 +515,21 @@
   # 3. See other examples below
 
-  user = os.environ['USER']
-  base_output_dir = os.path.join(output_bucket,user)
+  user = os.environ["USER"]
+  base_output_dir = os.path.join(output_bucket, user)
   for model in list_of_models:
     # Run workloads on the below clusters
     for cluster_config in [
-      # v5e_cluster_config,
-      # v6e_cluster_config,
-      v6e_cluster_config_yucmhab,
-      # another_config,
+        # v5e_cluster_config,
+        # v6e_cluster_config,
+        v6e_cluster_config_yucmhab,
+        # another_config,
     ]:
       # Run workloads in the following slice configurations
-      for num_slices in [1,]:
+      for num_slices in [
+          1,
+      ]:
         # Use the libtpu dependencies from:
         for libtpu_type in [
             # LibTpuType.CUSTOM
@@ -571,21 +537,18 @@
            # LibTpuType.NIGHTLY
         ]:
           wl_config = WorkloadConfig(
-              model=model,
-              num_slices=num_slices,
-              device_type=cluster_config.device_type,
-              base_output_directory=base_output_dir,
-              priority="medium",
-              max_restarts=0,
-              libtpu_type=libtpu_type,
-              libtpu_nightly_version="",
-              base_docker_image=base_docker_image,
-              pathways_config=None
-          )
-          command, name = generate_xpk_workload_cmd(
-              cluster_config=cluster_config,
-              wl_config=wl_config
+              model=model,
+              num_slices=num_slices,
+              device_type=cluster_config.device_type,
+              base_output_directory=base_output_dir,
+              priority="medium",
+              max_restarts=0,
+              libtpu_type=libtpu_type,
+              libtpu_nightly_version="",
+              base_docker_image=base_docker_image,
+              pathways_config=None,
           )
+          command, name = generate_xpk_workload_cmd(cluster_config=cluster_config, wl_config=wl_config)
 
           print(f"Name of the workload is: {name} \n")
           xpk_workload_names.append(name)
@@ -596,7 +559,7 @@
   for xpk_workload_name, xpk_workload_cmd in zip(xpk_workload_names, xpk_workload_cmds):
     return_code = run_command_with_updates(xpk_workload_cmd, xpk_workload_name)
     if return_code != 0:
-      print('Unable to run xpk workload: {xpk_workload_name}')
+      print(f"Unable to run xpk workload: {xpk_workload_name}")
 
   # Support Batch workloads one day. Note that this doesn't show the xpk logs per workload.
   # They are saved to file instead.
   #     batch=1,  # Parallel execution of workloads is not supported in XPK yet.
   #     dry_run=False,
   # )
-  # print(f'Return_codes: {return_codes}')
+# print(f'Return_codes: {return_codes}')
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   main()

diff --git a/benchmarks/recipes/__init__.py b/benchmarks/recipes/__init__.py
index 8ad88c258..bac559504 100644
--- a/benchmarks/recipes/__init__.py
+++ b/benchmarks/recipes/__init__.py
@@ -11,4 +11,4 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-"""
\ No newline at end of file
+"""

diff --git a/benchmarks/recipes/args_helper.py b/benchmarks/recipes/args_helper.py
index 33b819994..7a2438927 100644
--- a/benchmarks/recipes/args_helper.py
+++ b/benchmarks/recipes/args_helper.py
@@ -27,9 +27,7 @@
 DELETE = "delete"
 
 
-def _handle_delete(
-    cluster_config: mxr.XpkClusterConfig, user: str, **kwargs
-) -> int:
+def _handle_delete(cluster_config: mxr.XpkClusterConfig, user: str, **kwargs) -> int:
   """Handles the deletion of workloads.
Args: @@ -44,16 +42,11 @@ def _handle_delete( f"--project={cluster_config.project} --cluster={cluster_config.cluster_name}" f" --filter-by-job={first_five_chars} --zone={cluster_config.zone}" ) - print( - f"Deleting workloads starting with: {first_five_chars} using command:" - f" {delete_command}" - ) + print(f"Deleting workloads starting with: {first_five_chars} using command:" f" {delete_command}") os.system(delete_command) -def handle_cmd_args( - cluster_config: mxr.XpkClusterConfig, *actions: str, **kwargs -) -> bool: +def handle_cmd_args(cluster_config: mxr.XpkClusterConfig, *actions: str, **kwargs) -> bool: """Parses command-line arguments and executes the specified actions. Args: diff --git a/benchmarks/recipes/pw_remote_python_recipe.py b/benchmarks/recipes/pw_remote_python_recipe.py index c0a98acd2..0d2d64aa8 100644 --- a/benchmarks/recipes/pw_remote_python_recipe.py +++ b/benchmarks/recipes/pw_remote_python_recipe.py @@ -36,9 +36,7 @@ def main() -> int: xpk_path = "xpk" # Handle command line arguments using args_helper - should_continue = helper.handle_cmd_args( - cluster_config, helper.DELETE, xpk_path=xpk_path - ) + should_continue = helper.handle_cmd_args(cluster_config, helper.DELETE, xpk_path=xpk_path) if not should_continue: return 0 @@ -46,14 +44,8 @@ def main() -> int: # Configure test images user = os.environ["USER"] region = "-".join(cluster_config.zone.split("-")[:-1]) - proxy_image = ( - f"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/{user}/" - "proxy_server:latest" - ) - server_image = ( - f"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/{user}/" - "server:latest" - ) + proxy_image = f"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/{user}/" "proxy_server:latest" + server_image = f"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/{user}/" "server:latest" remote_python_image = f"gcr.io/{cluster_config.project}/{user}/remote_python_sidecar_latest:latest" runner = f"gcr.io/{cluster_config.project}/{user}_latest:latest" base_output_directory = f"gs://{user}-{region}/{user}" @@ -92,9 +84,7 @@ def main() -> int: xpk_path=xpk_path, num_steps=1000000, ) - command, name = mxr.generate_xpk_workload_cmd( - cluster_config=cluster_config, wl_config=wl_config - ) + command, name = mxr.generate_xpk_workload_cmd(cluster_config=cluster_config, wl_config=wl_config) print(f"Name of the workload is: {name} \n") xpk_workload_names.append(name) @@ -102,12 +92,8 @@ def main() -> int: print(f"XPK command to be used is: {command} \n") xpk_workload_cmds.append(command) - for xpk_workload_name, xpk_workload_cmd in zip( - xpk_workload_names, xpk_workload_cmds - ): - return_code = mxr.run_command_with_updates( - xpk_workload_cmd, xpk_workload_name - ) + for xpk_workload_name, xpk_workload_cmd in zip(xpk_workload_names, xpk_workload_cmds): + return_code = mxr.run_command_with_updates(xpk_workload_cmd, xpk_workload_name) if return_code != 0: print(f"Unable to run xpk workload: {xpk_workload_name}") diff --git a/benchmarks/xla_flags_library.py b/benchmarks/xla_flags_library.py index d3ca4d32e..6992cf23c 100644 --- a/benchmarks/xla_flags_library.py +++ b/benchmarks/xla_flags_library.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + """This file contains commonly-used XLA Flags.""" ### NOTICE ### @@ -29,9 +30,7 @@ DENSE_VMEM_LIMIT_FLAG = f" --xla_tpu_scoped_vmem_limit_kib={_DENSE_VMEM_LIMIT}" MOE_VMEM_LIMIT_FLAG = f" --xla_tpu_scoped_vmem_limit_kib={_MOE_VMEM_LIMIT}" -CUSTOM_VMEM_LIMIT_FLAG = ( - lambda vmem_limit: f"--xla_tpu_scoped_vmem_limit_kib={vmem_limit}" -) +CUSTOM_VMEM_LIMIT_FLAG = lambda vmem_limit: f"--xla_tpu_scoped_vmem_limit_kib={vmem_limit}" # Continuation Fusion (CF) for All Gather Collectives # Continuation Fusion is a form of parallelizing compute work with collectives. @@ -87,11 +86,9 @@ " --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false" " --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false" " --xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter=false" - " --xla_tpu_enable_sparse_core_collective_offload_all_gather=true" " --xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true" " --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true" - " --xla_tpu_enable_all_gather_offload_tracing=true" " --xla_tpu_enable_reduce_scatter_offload_tracing=true" " --xla_tpu_enable_all_reduce_offload_tracing=true" @@ -148,8 +145,7 @@ # xla_tpu_enable_data_parallel_all_reduce_opt: # optimize DCN all-reduces used for data parallel sharding DATA_PARALLEL_OVERLAP = ( - " --xla_tpu_enable_data_parallel_all_reduce_opt=true" - " --xla_tpu_data_parallel_opt_different_sized_ops=true" + " --xla_tpu_enable_data_parallel_all_reduce_opt=true" " --xla_tpu_data_parallel_opt_different_sized_ops=true" ) # Host offloading Flags. These are optimizations recommended when using host @@ -169,15 +165,11 @@ ) # Flags to optimize pipeline parallelism over DCN with large host offloading. -PIPELINING_FLAGS = ( - " --xla_tpu_iova_dma_chunk_size_bytes=16777216" # breaks DMA to/from host into 16M chunks -) +PIPELINING_FLAGS = " --xla_tpu_iova_dma_chunk_size_bytes=16777216" # breaks DMA to/from host into 16M chunks # Disable bundle-aware CostModel which was causing worse perf b/357103386. # Some fusions in the backward pass of the model were 3x slower without this. -DISABLE_BUNDLE_AWARE_COST_MODEL = ( - " --xla_tpu_use_bundle_aware_cost_model_for_fusions=false" -) +DISABLE_BUNDLE_AWARE_COST_MODEL = " --xla_tpu_use_bundle_aware_cost_model_for_fusions=false" # Enable Silent Data Corruption (SDC) Checker # SDC Checker will check for chip / ici / hardware corruption events. diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py index f719ca863..e2dd3eef4 100644 --- a/end_to_end/tpu/eval_assert.py +++ b/end_to_end/tpu/eval_assert.py @@ -25,7 +25,7 @@ def get_last_n_data(metrics_file, target, n=10): last_n_data = [] - with open(metrics_file, 'r', encoding='utf8') as file: + with open(metrics_file, "r", encoding="utf8") as file: lines = file.readlines() for line in lines[::-1]: metrics = json.loads(line) @@ -34,98 +34,105 @@ def get_last_n_data(metrics_file, target, n=10): if len(last_n_data) >= n: break return last_n_data - + def assert_metric_average(metrics_file, threshold, target): last_n_data = get_last_n_data(metrics_file, target) avg_value = sum(last_n_data) / len(last_n_data) # Checks for acceptable performance by asserting that the average metric (e.g. TFLOPs) # is greater than the threshold. 
- print(f'avg value of target {target} is {avg_value}') + print(f"avg value of target {target} is {avg_value}") assert avg_value >= float(threshold) - print('assert metric average passed.') + print("assert metric average passed.") + def test_final_loss(metrics_file, target_loss): target_loss = float(target_loss) - with open(metrics_file, 'r', encoding='utf8') as metrics: + with open(metrics_file, "r", encoding="utf8") as metrics: use_last_n_data = 10 - last_n_data = get_last_n_data(metrics_file, 'learning/loss', use_last_n_data) + last_n_data = get_last_n_data(metrics_file, "learning/loss", use_last_n_data) avg_last_n_data = sum(last_n_data) / len(last_n_data) print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}") print(f"Target loss is {target_loss}") assert avg_last_n_data < target_loss - print('Final loss test passed.') + print("Final loss test passed.") + def test_checkpointing(metrics_file, target, dataset_type): """Asserts over loss values from loaded checkpoint""" - metrics_file_saved = 'saved_' + metrics_file - metrics_file_restored = 'restored_' + metrics_file + metrics_file_saved = "saved_" + metrics_file + metrics_file_restored = "restored_" + metrics_file - with open(metrics_file_saved, 'r', encoding='utf8') as saved,\ - open(metrics_file_restored, 'r', encoding='utf8') as restored: + with ( + open(metrics_file_saved, "r", encoding="utf8") as saved, + open(metrics_file_restored, "r", encoding="utf8") as restored, + ): saved_loss = json.loads(saved.readlines()[-1])[target] restored_loss = json.loads(restored.readlines()[0])[target] # Checks that checkpoint restore was successful by comparing loss of last # step in saved checkpoint to loss of first step in restored checkpoint print("saved loss: ", saved_loss) print("restored loss: ", restored_loss) - if dataset_type=='tfds': + if dataset_type == "tfds": assert isclose(saved_loss, restored_loss, rel_tol=0.1) - elif dataset_type=='grain': - assert saved_loss==restored_loss + elif dataset_type == "grain": + assert saved_loss == restored_loss else: raise ValueError(f"Unknown dataset_type {dataset_type}. 
dataset_type must be tfds or grain") - print('checkpointing test passed.') + print("checkpointing test passed.") + def test_determinism(metrics_file, target): """Asserts over loss values from two runs""" - run_1 = 'run_1_' + metrics_file - run_2 = 'run_2_' + metrics_file + run_1 = "run_1_" + metrics_file + run_2 = "run_2_" + metrics_file - with open(run_1, 'r', encoding='utf8') as run_1_file,\ - open(run_2, 'r', encoding='utf8') as run_2_file: + with open(run_1, "r", encoding="utf8") as run_1_file, open(run_2, "r", encoding="utf8") as run_2_file: run_1_loss = json.loads(run_1_file.readlines()[-1])[target] run_2_loss = json.loads(run_2_file.readlines()[-1])[target] # Check that the two runs have the same loss print(f"Run 1 loss:{run_1_loss}", flush=True) print(f"Run 2 loss:{run_2_loss}", flush=True) - assert run_1_loss==run_2_loss - print('determinism test passed.') + assert run_1_loss == run_2_loss + print("determinism test passed.") + def test_vocab_creation(target): bucket_name = target.split("/")[2] vocab_path = "/".join(target.split("/")[3:]) storage_client = storage.Client() assert storage.Blob(bucket=storage_client.bucket(bucket_name), name=vocab_path).exists(storage_client) - print('vocab creation test passed.') + print("vocab creation test passed.") + def test_start_step(metrics_file, start_step_target): - with open(metrics_file, 'r', encoding='utf8') as metrics: + with open(metrics_file, "r", encoding="utf8") as metrics: start_step = json.loads(metrics.readlines()[0])["step"] print(f"Start step is {start_step}, start step target is {start_step_target}") - assert start_step==float(start_step_target) + assert start_step == float(start_step_target) print("Start step test passed.") + def main(argv: Sequence[str]) -> None: _, test_scenario, *test_vars = argv - if test_scenario == 'metrics_average': + if test_scenario == "metrics_average": assert_metric_average(*test_vars) - elif test_scenario == 'checkpoint_save_restore': - test_checkpointing(*test_vars, dataset_type='tfds') - elif test_scenario == 'grain_checkpoint_save_restore': - test_checkpointing(*test_vars, dataset_type='grain') - elif test_scenario == 'determinism': + elif test_scenario == "checkpoint_save_restore": + test_checkpointing(*test_vars, dataset_type="tfds") + elif test_scenario == "grain_checkpoint_save_restore": + test_checkpointing(*test_vars, dataset_type="grain") + elif test_scenario == "determinism": test_determinism(*test_vars) - elif test_scenario == 'vocab_creation': + elif test_scenario == "vocab_creation": test_vocab_creation(*test_vars) - elif test_scenario == 'final_loss': + elif test_scenario == "final_loss": test_final_loss(*test_vars) - elif test_scenario == 'test_start_step': + elif test_scenario == "test_start_step": test_start_step(*test_vars) else: - raise ValueError(f"Unrecognized test_scenario {test_scenario}") + raise ValueError(f"Unrecognized test_scenario {test_scenario}") if __name__ == "__main__": diff --git a/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh b/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh index 08cab692f..4a30b9ec6 100644 --- a/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh +++ b/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh @@ -3,7 +3,7 @@ # huggingface-cli download meta-llama/Llama-3.1-8B-Instruct --local-dir $CHECKPOINT_ORIGINAL # Or download the DeepSeek llama 8B model -export CHECKPOINT_ORIGINAL=/mnt/disks/persist/checkpoints/huggingface/Llama3.1-8B-Instruct +export CHECKPOINT_ORIGINAL=/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B 
huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-8B --local-dir $CHECKPOINT_ORIGINAL export CHECKPOINT_TPU_SCANNED=$CHECKPOINT_ORIGINAL/scanned_chkpt @@ -12,29 +12,41 @@ export TOKENIZER=assets/tokenizer_llama3.tiktoken export BASE_OUTPUT_PATH=$CHECKPOINT_ORIGINAL export RUN_NAME=unscanned_chkpt export CHECKPOINT_TPU_UNSCANNED=$BASE_OUTPUT_PATH/$RUN_NAME/checkpoints/0/items +export CHECKPOINT_TPU_CONVERTED_BACK=${CHECKPOINT_ORIGINAL}/converted_back +export MODEL_SIZE=llama3.1-8b -JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path=$CHECKPOINT_ORIGINAL --model-size=llama3.1-8b --maxtext-model-path=$CHECKPOINT_TPU_SCANNED --huggingface-checkpoint=true +# Remove previous checkpoints to have a clean start +rm -rf $CHECKPOINT_ORIGINAL/scanned_chkpt $CHECKPOINT_ORIGINAL/unscanned_chkpt ${CHECKPOINT_ORIGINAL}/converted_back + +# Convert the checkpoints +JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path=$CHECKPOINT_ORIGINAL --model-size=$MODEL_SIZE --maxtext-model-path=$CHECKPOINT_TPU_SCANNED --huggingface-checkpoint=true # Let's verify the generated scanned checkpoint to see if it matches the Huggingface golden logits python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 # If not, we can convert the checkpoint back from MaxText to Huggingface and compare with the original one -JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=convert_to_hf model_name=${MODEL_SIZE} hf_model_path=/tmp/hf_llama3_1 -python3 MaxText load_checkpoint.py +JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=convert_to_hf model_name=${MODEL_SIZE} hf_model_path=$CHECKPOINT_TPU_CONVERTED_BACK + +python3 MaxText/tests/hf_checkpoint_conversion_test.py --original_ckpt=${CHECKPOINT_ORIGINAL} --converted_ckpt=$CHECKPOINT_TPU_CONVERTED_BACK # If everything looks good, we move on to convert to the unrolled checkpoint for performant serving JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=${RUN_NAME} model_name=${MODEL_SIZE} force_unroll=true +# Let's verify the generated unscanned checkpoint to see if it matches the Huggingface golden logits +python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 JAX_PLATFORMS=tpu python MaxText/decode.py MaxText/configs/base.yml
tokenizer_path=$TOKENIZER run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED -# Output from huggingface 8B model on MaxText: +# You can also check the results from the scanned version: +JAX_PLATFORMS=tpu python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=$TOKENIZER run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items + +##### Output from huggingface llama 8B Instruct checkpoint on MaxText: #Input `I love to` -> ` travel and explore new places, but I also love to stay at home and relax. I'm a bit of a homebody, and I enjoy spending time with my family and friends. I'm a bit of a foodie, and I love trying new recipes and experimenting with different flavors and ingredients. I'm also a bit of a movie buff, and I love watching classic films and new releases alike. # I'm a bit of a hopeless romantic, and I believe in the idea of true love. I'm looking for someone who shares my values and my sense of humor, and who is always up for a good time. I'm a bit of a goofball, and I love to laugh and have fun. I'm looking for someone who can keep up with me and appreciate my quirks. # I'm a bit of a creative person, and I love to express myself through art, music, and writing. I'm a bit of a dreamer, and I love to imagine new possibilities and scenarios. I'm a bit of a perfectionist, and I always strive to do my best and be my best self. # ... -# Output from huggingface DeekSeek 8B model on MaxText: +##### Output from huggingface DeepSeek distilled llama 8B checkpoint on MaxText: # Input `I love to` -> ` write, but I'm not sure how to start a blog. I have some ideas, but I don't know where to begin. Maybe I should just start writing and see where it goes. But I don't want to end up with just a few posts and then stop. How do I keep going? # I also wonder if I should focus on a specific niche or write about a variety of topics. I like cooking, photography, and personal development. It's hard to decide which one to focus on. Maybe I can combine them somehow? @@ -161,3 +173,6 @@ JAX_PLATFORMS=tpu python MaxText/decode.py MaxText/configs/base.yml tokenizer_pa # ### 1. **Choose a Niche** # - **Start Small:** Begin with a specific niche that aligns with your interests, such as cooking, photography, or personal development. # - **Combine Interests:** If you're unsure, consider blending your passions, like a food blog with photography tips or a personal development blog with` + +##### Output from the original Huggingface 8B model on GPU: +# [{'generated_text': "I love to cook, but my problem is that I don't have enough space in my kitchen. I'm trying to figure out how to make the most of the limited space I have. I have a small kitchen, so I need to be strategic about how I use the existing space. Let me think about the different areas in my kitchen and how I can optimize them.\n\nFirst, my countertops are pretty cluttered. I have appliances like my coffee maker, toaster, and blender all out in the open. I wonder if I can find a better way to store these so they don't take up as much space.
Maybe I can get some sort of organizers or racks to keep them in one place.\n\nThen, my cabinets are mostly empty, but I don't have enough storage for all my spices and utensils. I have a lot of spices in those small glass containers, and they take up a lot of space. Maybe I can get some spice racks or use some of those small containers to organize them better. Also, my utensils are all jumbled up in a drawer; I should probably get some dividers or a drawer organizer to keep them separate.\n\nI have a small pantry, but it's not very organized. I have canned goods and dry goods all mixed together. I think I need some kind of shelf or bin to separate them into categories like baking supplies, snacks, and pantry staples. That way, I can find what I need without digging through everything.\n\nThe sink area is another issue. I always have dishes piling up next to the sink. I should get some dish racks or maybe a small trash can with a lid to keep things tidy. Also, there's a lot of water spots on the counter; maybe a wet wipe container would help keep things clean.\n\nI have a small dining area that's also cluttered. I have a small table and chairs, but there's always stuff laying on the table. I should probably get some storage benches or a shelf to put things like mail or coats out of the way.\n\nIn the corner where the stove is, there's not much space. I have a pan rack there, but it's not enough. I might need a larger rack or maybe some hanging storage solutions to maximize the vertical space.\n\nThe floor is another area I can utilize. I have some extra boxes under the sink that I can maybe move elsewhere or get some stacking bins to store things like seasonal items or extra kitchen gadgets.\n\nI also have a lot of small kitchen appliances that I don't use often. Maybe I can find a place to store them more efficiently, like on the top shelf of the pantry or using some vertical storage.\n\nI think I need to focus on each area one by one and figure out the best way to organize them. Maybe start with the countertops by getting some organizers for the appliances. Then move on to the cabinets and pantry, organizing spices and utensils. After that, work on the sink area and dining area. Finally, look into the vertical spaces and floor storage to maximize the space.\n\nI should also think about multi-functional furniture, like a bench that can hold some storage or a kitchen cart that can help move things around without taking up too much space. Maybe using the walls for more storage with some hooks or shelves.\n\nIt's a bit overwhelming, but if I take it step by step, I can definitely make a big difference in my kitchen. I'll start by decluttering and then figure out the best storage solutions for each area.\n**Step-by-Step Explanation and Answer:**\n\nTo optimize your small kitchen, follow this organized approach:\n\n1. **Countertops:**\n - **Clear Clutter:** Remove appliances like the coffee maker, toaster, and blender.\n - **Storage Solutions:**\n - Use a wall-mounted organizer or a shelf to keep these appliances tidy.\n - Opt for a slim drawer or tray for small items like cutlery or utensils.\n\n2. **Cabinets:**\n - **Spices:** Install a spice rack or use small, labeled containers for easy access.\n - **Utensils:** Utilize drawer dividers or a kitchen utensil organizer to keep them sorted.\n\n3. 
**Pantry:**\n - **Categorize Foods:** Use labeled bins or shelves for baking supplies, snacks, and staples.\n - **Vertical Storage:** Install shelving or use hanging storage for frequently used items.\n\n4. **Sink Area:**\n - **Dish Storage:** Place a dish rack near the sink for drying utensils or dishes.\n - **Trash and Cleaning Supplies:** Keep a small trash can with a lid and store cleaning products on a shelf or under the sink.\n\n5. **Dining Area:**\n - **Storage Solutions:**\n - Use a small bench with storage for items like mail or coats.\n - Install a shelf above the table for decor or extra storage.\n\n6. **"}] diff --git a/multihost_job.py b/multihost_job.py index 04be91729..81f7f7454 100644 --- a/multihost_job.py +++ b/multihost_job.py @@ -47,33 +47,36 @@ import shutil - def get_project(): completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) - project_outputs = completed_command.stdout.decode().strip().split('\n') - if len(project_outputs) < 1 or project_outputs[-1]=='': + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project '") - return project_outputs[-1] # The project name lives on the last line of the output + return project_outputs[-1] # The project name lives on the last line of the output + def get_zone(): completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True) - zone_outputs = completed_command.stdout.decode().strip().split('\n') - if len(zone_outputs) < 1 or zone_outputs[-1]=='': + zone_outputs = completed_command.stdout.decode().strip().split("\n") + if len(zone_outputs) < 1 or zone_outputs[-1] == "": sys.exit("You must specify the zone in the ZONE flag or set it with 'gcloud config set compute/zone '") - return zone_outputs[-1] # The zone name lives on the last line of the output + return zone_outputs[-1] # The zone name lives on the last line of the output + def get_run_name(): now = datetime.now() return os.getlogin() + "-" + now.strftime("%Y-%m-%d-%H-%M-%S") + def normalize_gcs_bucket_name(args): - """ Remove the gs:// from bucket_name if passed.""" - if len(args.BUCKET_NAME) > 5 and args.BUCKET_NAME[0:5]=="gs://": - args.BUCKET_NAME=args.BUCKET_NAME[5:] + """Remove the gs:// from bucket_name if passed.""" + if len(args.BUCKET_NAME) > 5 and args.BUCKET_NAME[0:5] == "gs://": + args.BUCKET_NAME = args.BUCKET_NAME[5:] return args + def print_flags(args): - """ Print configuration values after defaults have been filled in. """ + """Print configuration values after defaults have been filled in.""" print("Running multihost_job with the following configuration:") print(f"Project (--PROJECT) = {args.PROJECT}") print(f"Zone (--ZONE) = {args.ZONE}") @@ -88,10 +91,11 @@ def print_flags(args): print(f"Command to run (--COMMAND) = {args.COMMAND}") print(f"Enable Autocheckpoint (--ENABLE_AUTOCHECKPOINT) = {args.ENABLE_AUTOCHECKPOINT}\n") + def move_script_dir_to_gcs(script_dir, tmp_dir, zip_name, bucket_path): - """ Zip the script directory, cp it to GCS """ + """Zip the script directory, cp it to GCS""" original_working_directory = os.getcwd() - os.chdir(script_dir) # To tar script_dir, it is most convenient to cd there. + os.chdir(script_dir) # To tar script_dir, it is most convenient to cd there. # Zip script directory, storing it in the logging directory. 
os.makedirs(tmp_dir, exist_ok=True) @@ -109,28 +113,30 @@ def move_script_dir_to_gcs(script_dir, tmp_dir, zip_name, bucket_path): return captured_output + def run_create_resources(startup_script_file, args): - """ Run the Create Queued Resources (CQR) request """ + """Run the Create Queued Resources (CQR) request""" # pylint: disable=line-too-long - command = fr'gcloud alpha compute tpus queued-resources create {args.RUN_NAME} --accelerator-type={args.TPU_TYPE} --runtime-version={args.VERSION} --project={args.PROJECT} --zone={args.ZONE}' + command = rf"gcloud alpha compute tpus queued-resources create {args.RUN_NAME} --accelerator-type={args.TPU_TYPE} --runtime-version={args.VERSION} --project={args.PROJECT} --zone={args.ZONE}" if args.NUM_SLICES > 1: - command = command + f' --node-prefix={args.RUN_NAME} --node-count={args.NUM_SLICES}' + command = command + f" --node-prefix={args.RUN_NAME} --node-count={args.NUM_SLICES}" else: - command = command + f' --node-id={args.RUN_NAME}' + command = command + f" --node-id={args.RUN_NAME}" if args.CQR_EXTRA_ARGS: - command = command + ' ' + args.CQR_EXTRA_ARGS + command = command + " " + args.CQR_EXTRA_ARGS if args.ENABLE_AUTOCHECKPOINT: - command = command + ' --autocheckpoint-enabled' + command = command + " --autocheckpoint-enabled" - command = command + f' --metadata-from-file=startup-script={startup_script_file}' + command = command + f" --metadata-from-file=startup-script={startup_script_file}" captured_output = subprocess.run(command, check=False, shell=True, capture_output=True) return captured_output + def write_startup_script(zip_gcs_path, zip_name, log_name, bucket_path, startup_script_file, args): - """ Write the startup script locally into a file to be passed to the CQR command. """ + """Write the startup script locally into a file to be passed to the CQR command.""" startup_script = f"""#!/bin/bash mkdir -p {args.RUN_NAME} cd {args.RUN_NAME} @@ -150,8 +156,9 @@ def write_startup_script(zip_gcs_path, zip_name, log_name, bucket_path, startup_ f.write(startup_script) return startup_script + def get_env_command_str(num_slices): - """ Define environment variables on the TPUS """ + """Define environment variables on the TPUS""" # pylint: disable=line-too-long env_str = """curl -s 'http://metadata.google.internal/computeMetadata/v1/instance/attributes/tpu-env' -H 'Metadata-Flavor: Google' > /tmp/tpu-env # store the metadata NODE_ID=$(grep '^NODE_ID' /tmp/tpu-env | cut -d "'" -f 2) @@ -164,11 +171,13 @@ def get_env_command_str(num_slices): slice_assignment = """SLICE_ID=$(grep '^MEGASCALE_SLICE_ID' /tmp/tpu-env | cut -d "'" -f 2)""" return env_str + "\n" + slice_assignment + def finish_status_str(): # pylint: disable=line-too-long return """multihost_job finished main command on slice $SLICE_ID worker $WORKER_ID at $(date "+%Y-%m-%d %H:%M:%S") UTC with exit status $?. 
This worker will immediately send its logs to GCS.""" + def create_kill_command_str(args): # pylint: disable=line-too-long return f"""if [[ $SLICE_ID -eq 0 && $WORKER_ID -eq 0 ]]; then @@ -177,12 +186,14 @@ def create_kill_command_str(args): gcloud alpha compute tpus queued-resources delete {args.RUN_NAME} --force --quiet --project={args.PROJECT} --zone={args.ZONE} fi""" + def download_from_gcs(zip_gcs_path): return f""" echo "{write_download_from_gcs_sh(zip_gcs_path)}" > download_from_gcs.sh bash download_from_gcs.sh """ + def write_download_from_gcs_sh(zip_gcs_path): # pylint: disable=anomalous-backslash-in-string return f"""GCS_READ_SUCCESS=0 @@ -197,12 +208,14 @@ def write_download_from_gcs_sh(zip_gcs_path): }} done""" + def setup_ops_str(run_name, log_name): return f""" echo "{install_ops_script_str(run_name, log_name)}" > install_ops_wait_dpkg.sh bash install_ops_wait_dpkg.sh & """ + def install_ops_script_str(run_name, log_name): # pylint: disable=anomalous-backslash-in-string return f"""OPS_FILE=/etc/google-cloud-ops-agent/config.yaml @@ -227,6 +240,7 @@ def install_ops_script_str(run_name, log_name): sudo service google-cloud-ops-agent restart """ + def create_ops_config_str(run_name, log_name): return f"""logging: receivers: @@ -240,6 +254,7 @@ def create_ops_config_str(run_name, log_name): default_pipeline: receivers: [{run_name}_log]""" + def google_cloud_logging_url(run_name, project): # pylint: disable=line-too-long return f"https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22{run_name}_log%22%2529;?project={project}" @@ -249,46 +264,55 @@ def google_cloud_logging_single_host_url(run_name, project): # pylint: disable=line-too-long return f"https://console.cloud.google.com/logs/query;query=resource.type%3D%22gce_instance%22%20AND%0Alog_id%2528%22{run_name}_log%22%2529%20AND%0Alabels.%22agent.googleapis.com%2Flog_file_path%22%3D%20%22%2F{run_name}%2Fmain_command_log_slice_0_worker_0%22;?project={project}" + def gcs_bucket_url(bucket_name, bucket_dir, project): bucket_path = os.path.join(bucket_name, bucket_dir) return f"https://console.cloud.google.com/storage/browser/{bucket_path}?project={project}" + ################### Main ################### def main(raw_args=None) -> None: - ##### Define flags ##### - parser = argparse.ArgumentParser(description='TPU configuration options') - parser.add_argument('--TPU_TYPE', type=str, default='v4-8', - help='The type of the TPU') - parser.add_argument('--VERSION', type=str, default='tpu-ubuntu2204-base', - help='The runtime version of the TPU') - parser.add_argument('--NUM_SLICES', type=int, default=2, - help='The number of slices to run the job on') - parser.add_argument('--SCRIPT_DIR', type=str, default=os.getcwd(), - help='The local location of the directory to copy to the TPUs and run the main command from. \ - Defaults to current working directory.') - parser.add_argument('--COMMAND', type=str, default=None, required=True, - help='Main command to run on each TPU. 
\ + ##### Define flags ##### + parser = argparse.ArgumentParser(description="TPU configuration options") + parser.add_argument("--TPU_TYPE", type=str, default="v4-8", help="The type of the TPU") + parser.add_argument("--VERSION", type=str, default="tpu-ubuntu2204-base", help="The runtime version of the TPU") + parser.add_argument("--NUM_SLICES", type=int, default=2, help="The number of slices to run the job on") + parser.add_argument( + "--SCRIPT_DIR", + type=str, + default=os.getcwd(), + help="The local location of the directory to copy to the TPUs and run the main command from. \ + Defaults to current working directory.", + ) + parser.add_argument( + "--COMMAND", + type=str, + default=None, + required=True, + help="Main command to run on each TPU. \ This command is run from a copied version of SCRIPT_DIR on each TPU worker. \ You must include your dependency installations here, \ - e.g. --COMMAND=\'bash setup.sh && python3 train.py\'') - parser.add_argument('--BUCKET_NAME', type=str, default=None, required=True, - help='Name of GCS bucket, e.g. my-bucket') - parser.add_argument('--BUCKET_DIR', type=str, default="", - help='Directory within the GCS bucket, can be None, e.g. my-dir') - parser.add_argument('--PROJECT', type=str, default=None, - help='GCE project name, defaults to gcloud config project') - parser.add_argument('--ZONE', type=str, default=None, - help='GCE zone, e.g. us-central2-b, defaults to gcloud config compute/zone') - parser.add_argument('--RUN_NAME', type=str, default=None, - help='Run name used for temporary files, defaults to timestamp.') - parser.add_argument('--CQR_EXTRA_ARGS', type=str, default=None, - help='Additional arguments to be passed verbatim to the CQR request, e.g. \ - --CQR_EXTRA_ARGS="--reserved --service-account=my-service-account-email-address') - parser.add_argument('--ENABLE_AUTOCHECKPOINT', type=bool, default=False, - help='Whether to enable the Autocheckpoint feature') + e.g. --COMMAND='bash setup.sh && python3 train.py'", + ) + parser.add_argument("--BUCKET_NAME", type=str, default=None, required=True, help="Name of GCS bucket, e.g. my-bucket") + parser.add_argument("--BUCKET_DIR", type=str, default="", help="Directory within the GCS bucket, can be None, e.g. my-dir") + parser.add_argument("--PROJECT", type=str, default=None, help="GCE project name, defaults to gcloud config project") + parser.add_argument( + "--ZONE", type=str, default=None, help="GCE zone, e.g. us-central2-b, defaults to gcloud config compute/zone" + ) + parser.add_argument("--RUN_NAME", type=str, default=None, help="Run name used for temporary files, defaults to timestamp.") + parser.add_argument( + "--CQR_EXTRA_ARGS", + type=str, + default=None, + help='Additional arguments to be passed verbatim to the CQR request, e.g. \ + --CQR_EXTRA_ARGS="--reserved --service-account=my-service-account-email-address"', + ) + parser.add_argument( + "--ENABLE_AUTOCHECKPOINT", type=bool, default=False, help="Whether to enable the Autocheckpoint feature" + ) args = parser.parse_args(raw_args) - print("\nStarting multihost_job...\n", flush=True) #### Parse flags #### @@ -297,7 +321,7 @@ def main(raw_args=None) -> None: if not args.ZONE: args.ZONE = get_zone() if not args.RUN_NAME: - args.RUN_NAME = get_run_name() # Used for QR name, TPU_PREFIX, logging file, and tmp json file. + args.RUN_NAME = get_run_name()  # Used for QR name, TPU_PREFIX, logging file, and tmp json file.
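  # For reference, a typical invocation of this script might look like the following
  # (all values are illustrative placeholders, not defaults baked into this change):
  #   python3 multihost_job.py --TPU_TYPE=v4-8 --NUM_SLICES=2 --BUCKET_NAME=my-bucket \
  #     --COMMAND='bash setup.sh && python3 train.py'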
args = normalize_gcs_bucket_name(args) print_flags(args) @@ -343,17 +367,24 @@ def main(raw_args=None) -> None: print(f"Your job is being logged, follow it here:\n{google_cloud_logging_url(args.RUN_NAME, args.PROJECT)}\n") - print(f"To see the output of a single host, you may edit the slice and worker number in the log_file_path property here:"\ - f"\n{google_cloud_logging_single_host_url(args.RUN_NAME, args.PROJECT)}\n") + print( + f"To see the output of a single host, you may edit the slice and worker number in the log_file_path property here:" + f"\n{google_cloud_logging_single_host_url(args.RUN_NAME, args.PROJECT)}\n" + ) - print(f"When your job is finished, the main command log is in the GCS bucket here:"\ - f"\n{gcs_bucket_url(args.BUCKET_NAME, bucket_dir, args.PROJECT)}\n") + print( + f"When your job is finished, the main command log is in the GCS bucket here:" + f"\n{gcs_bucket_url(args.BUCKET_NAME, bucket_dir, args.PROJECT)}\n" + ) print("View the status of the created TPUs via: ") - print(f"gcloud alpha compute tpus queued-resources list "\ - f"--filter={args.RUN_NAME} --zone={args.ZONE} --project={args.PROJECT}\n") + print( + f"gcloud alpha compute tpus queued-resources list " + f"--filter={args.RUN_NAME} --zone={args.ZONE} --project={args.PROJECT}\n" + ) return 0 -if __name__ == '__main__': + +if __name__ == "__main__": print("Name is __main__") main() diff --git a/multihost_runner.py b/multihost_runner.py index 31d69ef63..0bc126cd2 100644 --- a/multihost_runner.py +++ b/multihost_runner.py @@ -46,45 +46,67 @@ import os import re + ##### Define flags ##### def get_project(): completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) - project_outputs = completed_command.stdout.decode().strip().split('\n') - if len(project_outputs) < 1 or project_outputs[-1]=='': + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": sys.exit("You must specify the project in the PROJECT flag or set it with 'gcloud config set project '") - return project_outputs[-1] # The project name lives on the last line of the output + return project_outputs[-1] # The project name lives on the last line of the output + def get_zone(): completed_command = subprocess.run(["gcloud", "config", "get", "compute/zone"], check=True, capture_output=True) - zone_outputs = completed_command.stdout.decode().strip().split('\n') - if len(zone_outputs) < 1 or zone_outputs[-1]=='': + zone_outputs = completed_command.stdout.decode().strip().split("\n") + if len(zone_outputs) < 1 or zone_outputs[-1] == "": sys.exit("You must specify the zone in the ZONE flag or set it with 'gcloud config set compute/zone '") - return zone_outputs[-1] # The zone name lives on the last line of the output + return zone_outputs[-1] # The zone name lives on the last line of the output + def default_run_name(): now = datetime.now() return now.strftime("%Y-%m-%d-%H-%M-%S") -parser = argparse.ArgumentParser(description='TPU configuration options') -parser.add_argument('--TPU_PREFIX', type=str, default=None, required=True, - help="Prefix of worker TPU's. E.g. if TPU's are named user-0 and user-1, \ - TPU_PREFIX should be set as user") -parser.add_argument('--PROJECT', type=str, default=None, - help='GCE project name, defaults to gcloud config project') -parser.add_argument('--ZONE', type=str, default=None, - help='GCE zone, e.g. 
us-central2-b, defaults to gcloud config compute/zone') -parser.add_argument('--SCRIPT_DIR', type=str, default=os.getcwd(), - help="The local location of the directory to copy to the TPUs and run the main command from. \ - Defaults to current working directory.") -parser.add_argument('--COMMAND', type=str, default=None, required=True, - help="Main command to run on each TPU. \ - This command is run from a copied version of SCRIPT_DIR on each TPU worker.") -parser.add_argument('--RUN_NAME', type=str, default=default_run_name(), - help="Name for the code directory on the TPU") -parser.add_argument('--USE_EXISTING_FOLDER', type=str, default="False", - help='If true, use the existing code directory on the TPU') -parser.add_argument('--INTERNAL_IP', type=str, default="False", - help="Set true if running script locally from a TPU or GCE instance, false otherwise.") + +parser = argparse.ArgumentParser(description="TPU configuration options") +parser.add_argument( + "--TPU_PREFIX", + type=str, + default=None, + required=True, + help="Prefix of worker TPUs. E.g. if TPUs are named user-0 and user-1, \ + TPU_PREFIX should be set as user", +) +parser.add_argument("--PROJECT", type=str, default=None, help="GCE project name, defaults to gcloud config project") +parser.add_argument( + "--ZONE", type=str, default=None, help="GCE zone, e.g. us-central2-b, defaults to gcloud config compute/zone" +) +parser.add_argument( + "--SCRIPT_DIR", + type=str, + default=os.getcwd(), + help="The local location of the directory to copy to the TPUs and run the main command from. \ + Defaults to current working directory.", +) +parser.add_argument( + "--COMMAND", + type=str, + default=None, + required=True, + help="Main command to run on each TPU. \ + This command is run from a copied version of SCRIPT_DIR on each TPU worker.", +) +parser.add_argument("--RUN_NAME", type=str, default=default_run_name(), help="Name for the code directory on the TPU") +parser.add_argument( + "--USE_EXISTING_FOLDER", type=str, default="False", help="If true, use the existing code directory on the TPU" +) +parser.add_argument( + "--INTERNAL_IP", + type=str, + default="False", + help="Set true if running script locally from a TPU or GCE instance, false otherwise.", +) args = parser.parse_args() args.USE_EXISTING_FOLDER = args.USE_EXISTING_FOLDER.lower() == "true" args.INTERNAL_IP = args.INTERNAL_IP.lower() == "true" @@ -95,69 +117,90 @@ def default_run_name(): if args.USE_EXISTING_FOLDER is True and not args.RUN_NAME: raise ValueError("When USE_EXISTING_FOLDER is true, RUN_NAME must be specified.") -Slice = namedtuple('Slice', ['name', 'slice_num', 'num_workers', 'version']) +Slice = namedtuple("Slice", ["name", "slice_num", "num_workers", "version"]) + def get_slices(): - """ Returns a list of slices matching TPU_PREFIX """ + """Returns a list of slices matching TPU_PREFIX""" command = [ - "gcloud", "alpha", "compute", "tpus", "tpu-vm", "list", - f"--filter=name~{args.TPU_PREFIX}", "--format=csv(name,accelerator_type)", - f"--project={args.PROJECT}", f"--zone={args.ZONE}" + "gcloud", + "alpha", + "compute", + "tpus", + "tpu-vm", + "list", + f"--filter=name~{args.TPU_PREFIX}", + "--format=csv(name,accelerator_type)", + f"--project={args.PROJECT}", + f"--zone={args.ZONE}", ] try: completed_command = subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: - print(f"Error occurred trying to find TPU slices named {args.TPU_PREFIX} or matching regex \n {args.TPU_PREFIX}-[0-9]+ " - f"in project
{args.PROJECT} zone {args.ZONE}") + print( + f"Error occurred trying to find TPU slices named {args.TPU_PREFIX} or matching regex \n {args.TPU_PREFIX}-[0-9]+ " + f"in project {args.PROJECT} zone {args.ZONE}" + ) print(f"Error is:\n {e.stderr}") return [] instances = completed_command.stdout.decode() - instance_list = instances.strip().split('\n') - instance_list = filter_instances(instance_list[1:], args.TPU_PREFIX) # First row is headers + instance_list = instances.strip().split("\n") + instance_list = filter_instances(instance_list[1:], args.TPU_PREFIX) # First row is headers num_slices = len(instance_list) slices = [None for _ in range(num_slices)] if num_slices > 0: print(f"{num_slices} slices found.", flush=True) else: - print(f"No TPUs found with name {args.TPU_PREFIX} or matching regex {args.TPU_PREFIX}-[0-9]+ " - "in project {args.PROJECT} and zone {args.ZONE}.") + print( + f"No TPUs found with name {args.TPU_PREFIX} or matching regex {args.TPU_PREFIX}-[0-9]+ " + f"in project {args.PROJECT} and zone {args.ZONE}." + ) return [] - slice_names = [instance.split(',')[0] for instance in instance_list] - slice_versions = [instance.split(',')[1] for instance in instance_list] + slice_names = [instance.split(",")[0] for instance in instance_list] + slice_versions = [instance.split(",")[1] for instance in instance_list] # Get number of workers in any slice (assume same worker count for all slices.) command = [ - "gcloud", "compute", "tpus", "describe", slice_names[0], - "--flatten=networkEndpoints[]", "--format=csv[no-heading](networkEndpoints.ipAddress)", - f"--project={args.PROJECT}", f"--zone={args.ZONE}" + "gcloud", + "compute", + "tpus", + "describe", + slice_names[0], + "--flatten=networkEndpoints[]", + "--format=csv[no-heading](networkEndpoints.ipAddress)", + f"--project={args.PROJECT}", + f"--zone={args.ZONE}", ] completed_command = subprocess.run(command, capture_output=True, check=True) - num_workers = len(completed_command.stdout.decode().strip().split('\n')) + num_workers = len(completed_command.stdout.decode().strip().split("\n")) for slice_name, version in zip(slice_names, slice_versions): if num_slices > 1: - slice_num = int(slice_name.split('-')[-1]) + slice_num = int(slice_name.split("-")[-1]) else: slice_num = 0 slices[slice_num] = Slice(slice_name, slice_num, num_workers, version) return slices + def filter_instances(instance_list, tpu_prefix): # First look for exact match with tpu_prefix for instance in instance_list: - if instance.split(',')[0] == tpu_prefix: + if instance.split(",")[0] == tpu_prefix: return [instance] # If no exact match, reg-exp full match "-[0-9]+" re_pattern = tpu_prefix + "-[0-9]+" - return [instance for instance in instance_list if re.fullmatch(re_pattern, instance.split(',')[0])] + return [instance for instance in instance_list if re.fullmatch(re_pattern, instance.split(",")[0])] + def write_kill_script(kill_processes_script_name): kill_processes_script = os.path.join(args.SCRIPT_DIR, kill_processes_script_name) with open(kill_processes_script, "w", encoding="utf-8") as f: f.write(kill_existing_processes_str()) + def kill_existing_processes_str(): return """#!/bin/bash _TPU_VERSION_NAME="${1}" @@ -181,17 +224,18 @@ def kill_existing_processes_str(): fi sudo rm -f /tmp/libtpu_lockfile""" + def scps(slices, run_name_dir, zip_name): - """ Zip the script directory, scp it to the TPUs, and unzip it there.
""" + """Zip the script directory, scp it to the TPUs, and unzip it there.""" original_working_directory = os.getcwd() - os.chdir(args.SCRIPT_DIR) # To tar script_dir, it is most convenient to cd there. + os.chdir(args.SCRIPT_DIR) # To tar script_dir, it is most convenient to cd there. # Zip script directory # Save the zip both to the logging directory, and the script directory. # It will be removed from the script directory after the transfer to the TPUs os.makedirs(run_name_dir, exist_ok=True) zip_path = os.path.join(run_name_dir, zip_name) - command = ["tar","--exclude=tmp", "-czf", zip_path, "./"] + command = ["tar", "--exclude=tmp", "-czf", zip_path, "./"] subprocess.run(command, check=True) # Move zip file to each tpuvm worker @@ -200,8 +244,17 @@ def scps(slices, run_name_dir, zip_name): for cur_slice in slices: for worker_num in range(cur_slice.num_workers): command = [ - "gcloud", "compute", "tpus", "tpu-vm", "scp", f"--worker={worker_num}", zip_path, - f"{cur_slice.name}:~/", "--strict-host-key-checking=no", f"--project={args.PROJECT}", f"--zone={args.ZONE}" + "gcloud", + "compute", + "tpus", + "tpu-vm", + "scp", + f"--worker={worker_num}", + zip_path, + f"{cur_slice.name}:~/", + "--strict-host-key-checking=no", + f"--project={args.PROJECT}", + f"--zone={args.ZONE}", ] if args.INTERNAL_IP: command.append("--internal-ip") @@ -217,15 +270,16 @@ def scps(slices, run_name_dir, zip_name): return return_code + def execute_main_command(main_command, slices, local_log_dir, zip_name): - """ Run the main command on each worker, logging each separately. """ - kill_script_name = "kill_existing_processes.sh" # File written on worker machines + """Run the main command on each worker, logging each separately.""" + kill_script_name = "kill_existing_processes.sh" # File written on worker machines commands = [] output_logs = [] worker_list = [] os.makedirs(local_log_dir, exist_ok=True) - for slice_num, cur_slice in enumerate(slices): + for slice_num, cur_slice in enumerate(slices): for worker_num in range(cur_slice.num_workers): output_filename = os.path.join(local_log_dir, f"output_slice_{cur_slice.slice_num:04d}_worker_{worker_num:04d}.txt") output_logs.append(output_filename) @@ -237,15 +291,33 @@ def execute_main_command(main_command, slices, local_log_dir, zip_name): kill_existing_command = f"bash {kill_script_name} {cur_slice.version}" if args.USE_EXISTING_FOLDER is False: - remote_command_list = [mkdir_command , mv_zip_command , cd_command , unzip_command , - write_kill_script_command , kill_existing_command , main_command] + remote_command_list = [ + mkdir_command, + mv_zip_command, + cd_command, + unzip_command, + write_kill_script_command, + kill_existing_command, + main_command, + ] else: - remote_command_list = [cd_command, write_kill_script_command , kill_existing_command , main_command] + remote_command_list = [cd_command, write_kill_script_command, kill_existing_command, main_command] remote_command_list_str = " && ".join(remote_command_list) - gcloud_command=[ - "gcloud", "alpha", "compute", "tpus", "tpu-vm", "ssh", cur_slice.name, f"--worker={worker_num}", - "--command", remote_command_list_str, "--strict-host-key-checking=no", - f"--project={args.PROJECT}", f"--zone={args.ZONE}"] + gcloud_command = [ + "gcloud", + "alpha", + "compute", + "tpus", + "tpu-vm", + "ssh", + cur_slice.name, + f"--worker={worker_num}", + "--command", + remote_command_list_str, + "--strict-host-key-checking=no", + f"--project={args.PROJECT}", + f"--zone={args.ZONE}", + ] if args.INTERNAL_IP: 
gcloud_command.append("--internal-ip") commands.append(gcloud_command) @@ -254,12 +326,16 @@ def execute_main_command(main_command, slices, local_log_dir, zip_name): return_code, return_codes = run_commands(commands, 0, "MAIN COMMAND", worker_list, output_logs=output_logs) if return_code > 0: failure_index = next((i for i, x in enumerate(return_codes) if x), None) - print(f"Main command failed on slice {worker_list[failure_index][0]} worker"\ - f" {worker_list[failure_index][1]} with error code {return_codes[failure_index]}, see logs for details", flush=True) + print( + f"Main command failed on slice {worker_list[failure_index][0]} worker" + f" {worker_list[failure_index][1]} with error code {return_codes[failure_index]}, see logs for details", + flush=True, + ) return return_code + def run_commands(commands, id_to_print, jobname, worker_list, is_shell=False, output_logs=None, fail_fast=True): - ''' Runs commands in parallel. + """Runs commands in parallel. Inputs: commands: list of n commands, each command is a a list of strings id_to_print: which command is printed to the terminal, typically 0 or None @@ -268,7 +344,7 @@ def run_commands(commands, id_to_print, jobname, worker_list, is_shell=False, ou is_shell: Boolean directly passed as shell argument to subprocess.Popen output_logs: list of n log paths, each command will output to each log. fail_fast: If true, when one command fails immediately terminate others - ''' + """ children = [] start_time = datetime.now() @@ -287,7 +363,7 @@ def run_commands(commands, id_to_print, jobname, worker_list, is_shell=False, ou while True: returncodes = [child.poll() for child in children] - max_returncode = max([0]+[r for r in returncodes if r is not None]) + max_returncode = max([0] + [r for r in returncodes if r is not None]) completed = len([r for r in returncodes if r is not None]) total = len(returncodes) seconds_elapsed = (datetime.now() - start_time).total_seconds() @@ -300,8 +376,7 @@ def run_commands(commands, id_to_print, jobname, worker_list, is_shell=False, ou print(f"[t={seconds_elapsed:.2f}, {jobname}] Completed {completed}/{total}{slow_str}...") if seconds_elapsed >= 60 and not 0 in returncodes and jobname == "SCP": - print("SCP operation timed out - terminating all processes."\ - " Please check that --INTERNAL_IP flag is set correctly.") + print("SCP operation timed out - terminating all processes." " Please check that --INTERNAL_IP flag is set correctly.") for child in children: child.terminate() max_returncode = 255 @@ -319,14 +394,17 @@ def run_commands(commands, id_to_print, jobname, worker_list, is_shell=False, ou time.sleep(1) return max_returncode, returncodes + def assert_script_dir_exists(script_dir): if not os.path.isdir(script_dir): sys.exit(f"No directory named {script_dir} found.") + class Tee: - """ Helper class to print subprocess to both stdout and a log file. 
""" + """Helper class to print subprocess to both stdout and a log file.""" + def __init__(self, *files, bufsize=1): - files = [x.fileno() if hasattr(x, 'fileno') else x for x in files] + files = [x.fileno() if hasattr(x, "fileno") else x for x in files] read_fd, write_fd = os.pipe() pid = os.fork() if pid: @@ -353,6 +431,7 @@ def close(self): os.close(self._fileno) os.waitpid(self.child_pid, 0) + ################### Main ################### def main() -> None: print("Starting multihost runner...", flush=True) @@ -392,5 +471,6 @@ def main() -> None: print(f"Main command finished with errors, check the logs located in: {local_log_dir}", flush=True) return return_code -if __name__ == '__main__': + +if __name__ == "__main__": main()