diff --git a/MaxText/configs/gpu_smoke_test.yml b/MaxText/configs/gpu_smoke_test.yml
index cafee5387..b9d34c74a 100644
--- a/MaxText/configs/gpu_smoke_test.yml
+++ b/MaxText/configs/gpu_smoke_test.yml
@@ -3,7 +3,6 @@ base_config: "base.yml"
 hardware: "gpu"
 attention: "dot_product"
 base_emb_dim: 8
-base_emb_dim: 8
 base_num_query_heads: 4
 base_num_kv_heads: 4
 base_mlp_dim: 32
diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py
index 0f6d6111c..4363874be 100644
--- a/MaxText/convert_gpt3_ckpt_from_paxml.py
+++ b/MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -89,8 +89,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       "checkpoint_period=1",
       "async_checkpointing=false",
   ]
-  pyconfig.initialize(base_args)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(base_args)
   init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2)
   devices_array = max_utils.create_device_mesh(cfg)
   mesh = Mesh(devices_array, cfg.mesh_axes)
diff --git a/MaxText/decode.py b/MaxText/decode.py
index 2e07228b7..8b114d16d 100644
--- a/MaxText/decode.py
+++ b/MaxText/decode.py
@@ -30,8 +30,7 @@ def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
 
-  pyconfig.initialize(argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(argv)
   validate_config(config)
   max_utils.print_system_information()
diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py
index a8c8fddfc..b73f94be7 100644
--- a/MaxText/generate_param_only_checkpoint.py
+++ b/MaxText/generate_param_only_checkpoint.py
@@ -142,8 +142,8 @@ def generate_decode_checkpoint(config):
 
 def main(argv: Sequence[str]) -> None:
   print(argv)
-  pyconfig.initialize(argv)
-  generate_decode_checkpoint(pyconfig.config)
+  config = pyconfig.initialize(argv)
+  generate_decode_checkpoint(config)
 
 
 if __name__ == "__main__":
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index e9bb5b6ed..a30da7e7e 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -356,8 +356,7 @@ def run_benchmarks(config):
 
 def main(argv):
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-  pyconfig.initialize(argv)
-  run_benchmarks(pyconfig.config)
+  run_benchmarks(pyconfig.initialize(argv))
 
 
 if __name__ == "__main__":
diff --git a/MaxText/inference_microbenchmark_sweep.py b/MaxText/inference_microbenchmark_sweep.py
index 82c541b79..9b764343a 100644
--- a/MaxText/inference_microbenchmark_sweep.py
+++ b/MaxText/inference_microbenchmark_sweep.py
@@ -45,8 +45,7 @@ def main():
     - flatten_microbenchmark_results: Whether or not to flatten results. Should be true
   """
-  pyconfig.initialize(sys.argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(sys.argv)
   base_run_name = config.run_name
 
   with open(config.inference_metadata_file, encoding="utf-8") as json_file:
diff --git a/MaxText/inference_mlperf/requirements.txt b/MaxText/inference_mlperf/requirements.txt
index 4a72ef22e..51ae8c5bf 100644
--- a/MaxText/inference_mlperf/requirements.txt
+++ b/MaxText/inference_mlperf/requirements.txt
@@ -5,3 +5,4 @@ absl-py==1.4.0
 rouge-score==0.1.2
 sentencepiece==0.1.99
 accelerate==0.21.0
+omegaconf
diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
index acbbb40a2..4695668a7 100644
--- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py
+++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -258,11 +258,11 @@ def convert_orbax_hf(hf_model_path, config):
 
 def main(argv: Sequence[str]):
-  pyconfig.initialize(argv[:-1])
+  config = pyconfig.initialize(argv[:-1])
   hf_model_path = argv[-1].split("=")[1]
   print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}")
-  convert_orbax_hf(hf_model_path, pyconfig.config)
+  convert_orbax_hf(hf_model_path, config)
 
 
 if __name__ == "__main__":
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index ef7412ca0..ae89d12ed 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -54,8 +54,7 @@ class MaxEngineConfig:
   """Engine specific config class to allow using multiple MaxEngine instances in an inference run.
 
-  The default pyconfig.config is a global param shared across multiple instances and doesn't
-  allow using different config for each MaxEngine instance.
+  TODO: evaluate the need for this given the restructured pyconfig.py
   """
 
   def __init__(self, keys):
diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py
index e45c3b6b0..74d689c80 100644
--- a/MaxText/maxengine_server.py
+++ b/MaxText/maxengine_server.py
@@ -63,6 +63,5 @@ def main(config):
 if __name__ == "__main__":
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-  pyconfig.initialize(sys.argv)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(sys.argv)
   main(cfg)
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 524f343bd..5c0d5c007 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -28,7 +28,9 @@
 import accelerator_to_spec_map
 import max_logging
 import max_utils
-import yaml
+import omegaconf
+
+OmegaConf = omegaconf.OmegaConf
 
 # pylint: disable=line-too-long
 
@@ -64,7 +66,7 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None:
   if s not in valid_kv_quant_axis:  # currently supported kv_quant_axis
     raise ValueError("Invalid kv_quant_axis was passed. Valid options ", valid_kv_quant_axis)
   if quantize_kvcache and s == "":
-    raise ValueError("kv_quant_axis can not be '' when quantize_kvcache is True")
+    raise ValueError("kv_quant_axis cannot be '' when quantize_kvcache is True")
 
 
 def validate_attention_kernel(s: str) -> None:
@@ -92,7 +94,7 @@ def validate_periodic_profiler(profiler, profile_periodically_period, profiler_steps):
     raise ValueError("Periodic profiler requested but no profiler was set, set it via profiler=xplane or profiler=nsys")
   if profile_periodically_period < profiler_steps:
     raise ValueError(
-        f"You must set the profile_periodically_period {profile_periodically_period} at least as long profiler_steps {profiler_steps}."
+        f"You must set the profile_periodically_period {profile_periodically_period} at least as long as profiler_steps {profiler_steps}."
     )
 
 
@@ -108,8 +110,8 @@ def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_length: int):
   if max_target_length < max_prefill_length:
     # valid max_target_length = max_prefill_length for existing logit checks
     raise ValueError(
-        f"Invalid max_target_length {max_target_length}, this should be sum of "
-        f"max_prefill_predict_length ({max_prefill_length}) and max output length expected."
+        f"Invalid max_target_length {max_target_length}, this should be the sum of "
+        f"max_prefill_predict_length ({max_prefill_length}) and the expected max output length."
     )
@@ -266,16 +268,14 @@ def validate_and_assign_remat_tensors(keys):
   return keys
 
 
-_config = None
-config = None
-
-
 def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]:
   return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l
 
 
 class _HyperParameters:
   # pylint: disable=missing-class-docstring
+  # This class is responsible for loading, merging, and overriding the configuration.
+
   def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]):
     for environment_var in os.environ:
       if environment_var[: len(_MAX_PREFIX)] == _MAX_PREFIX:
@@ -285,14 +285,16 @@ class _HyperParameters:
       if not environment_var[len(_MAX_PREFIX) :].isupper():
         raise ValueError(f"We received env `{environment_var}` but it isn't all uppercase.")
 
-  def _load_kwargs(self, argv: list[str], **kwargs):
-    args_dict = dict(a.split("=", 1) for a in argv[2:])
-    args_dict.update(kwargs)
-    return args_dict
-
   def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]:
-    """Update model config from environment and command line"""
-    raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs)
+    """Update model config from environment and command line using OmegaConf overrides."""
+    # Use OmegaConf.from_cli to capture CLI arguments.
+    cli_cfg = OmegaConf.from_cli(argv[2:])
+    # Also create a configuration from any extra keyword arguments.
+    kwargs_cfg = OmegaConf.create(kwargs)
+    # Merge command-line and keyword arguments.
+    cmdline_cfg = OmegaConf.merge(cli_cfg, kwargs_cfg)
+    raw_data_from_cmd_line = OmegaConf.to_container(cmdline_cfg, resolve=True)
+
     updated_keys = []
 
     for k in raw_data_from_cmd_line:
@@ -303,7 +305,7 @@ class _HyperParameters:
       if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ:
         raise ValueError(f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.")
 
-      if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ:
+      if k not in raw_data_from_cmd_line and yaml_key_to_env_key(k) not in os.environ:
         raw_keys[k] = raw_data_from_yaml[k]
         continue
 
@@ -334,9 +336,9 @@ class _HyperParameters:
     return updated_keys
 
   def _load_config(self, config_name: str) -> dict[str, Any]:
-    """Loads the YAML config from a file with a given name."""
-    with open(config_name, "r", encoding="utf-8") as yaml_file:
-      raw_data_from_yaml = yaml.safe_load(yaml_file)
+    """Loads the YAML config from a file using OmegaConf, and resolves inheritance."""
+    base_cfg = OmegaConf.load(config_name)
+    raw_data_from_yaml = OmegaConf.to_container(base_cfg, resolve=True)
 
     # Load data from parent config. Note that inheritance has override
     # semantics, and the path is relative to the current config.
@@ -351,6 +353,7 @@ def _load_config(self, config_name: str) -> dict[str, Any]:
         loaded_parent_config_filename = parent_config_filename
 
       base_config = self._load_config(loaded_parent_config_filename)
+      # Override base_config with values from raw_data_from_yaml.
       for key, value in raw_data_from_yaml.items():
         base_config[key] = value
       return base_config
@@ -454,7 +457,10 @@ def user_init(raw_keys):
         raw_keys["global_batch_size_to_eval_on"],
         raw_keys["micro_batch_size_to_eval_on"],
     ) = calculate_global_batch_sizes(
-        raw_keys["eval_per_device_batch_size"], raw_keys["expansion_factor_real_data"], get_num_target_devices(raw_keys), 1
+        raw_keys["eval_per_device_batch_size"],
+        raw_keys["expansion_factor_real_data"],
+        get_num_target_devices(raw_keys),
+        1,
     )
 
   raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
@@ -511,9 +517,10 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):
   if not os.path.isfile(file_path):
     dir_path = os.path.dirname(os.path.realpath(__file__))
     file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml")
-  with open(file_path, "r", encoding="utf-8") as file:
-    model_vars = yaml.safe_load(file)
-  updated_keys = list(model_vars.keys())
+  # Use OmegaConf to load the model-specific configuration.
+  model_vars = OmegaConf.load(file_path)
+  model_vars = OmegaConf.to_container(model_vars, resolve=True)
+  updated_keys = list(model_vars.keys())
   raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name)
   return updated_keys
@@ -856,30 +863,33 @@ def using_expert_parallelism(raw_keys) -> bool:
   return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1
 
 
-class HyperParameters:  # pylint: disable=missing-class-docstring
+class HyperParameters:
+  """Wrapper class to expose the configuration in a read-only manner."""
 
-  def __init__(self):
-    pass
+  def __init__(self, config):
+    object.__setattr__(self, "_config", config)
 
   def __getattr__(self, attr):
-    if attr not in _config.keys:
-      raise ValueError(f"Requested key {attr}, not in config")
-    return _config.keys[attr]
+    try:
+      # Look the requested key up in the underlying config's key dict.
+      return object.__getattribute__(self, "_config").keys[attr]
+    except (AttributeError, KeyError) as exc:
+      raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") from exc
 
   def __setattr__(self, attr, value):
-    raise ValueError
+    raise ValueError(f"Config is immutable; cannot set attribute '{attr}'")
 
   def get_keys(self):
    return self._config.keys
 
 
 def initialize(argv, **kwargs):
-  global _config, config
   _config = _HyperParameters(argv, **kwargs)
-  config = HyperParameters()
+  config = HyperParameters(_config)
+  return config
 
 
 if __name__ == "__main__":
-  initialize(sys.argv)
-  print(config.steps)
-  r = range(config.steps)
+  main_config = initialize(sys.argv)
+  print(main_config.steps)
+  r = range(main_config.steps)
diff --git a/MaxText/scratch_code/mixtral-numerical-verification.ipynb b/MaxText/scratch_code/mixtral-numerical-verification.ipynb
index 0d5526305..0a6018058 100644
--- a/MaxText/scratch_code/mixtral-numerical-verification.ipynb
+++ b/MaxText/scratch_code/mixtral-numerical-verification.ipynb
@@ -45,7 +45,7 @@
    "import pyconfig\n",
    "from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
    "\n",
-   "pyconfig.initialize(\n",
+   "config_maxtext = pyconfig.initialize(\n",
    "    [None, \"configs/base.yml\"],\n",
    "    base_emb_dim=4096,\n",
    "    base_num_query_heads=32,\n",
@@ -73,7 +73,6 @@
    "    capacity_factor=-1,\n",
    "    scan_layers=False,\n",
    ")\n",
"config_maxtext = pyconfig.config\n", "\n", "config_hf = MixtralConfig(\n", " vocab_size=config_maxtext.vocab_size,\n", diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py index 61864efdc..4759cb50a 100644 --- a/MaxText/standalone_checkpointer.py +++ b/MaxText/standalone_checkpointer.py @@ -106,8 +106,7 @@ def add_entropy_to_checkpoint(state): def main(argv: Sequence[str]) -> None: jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - pyconfig.initialize(argv) - config = pyconfig.config + config = pyconfig.initialize(argv) validate_train_config(config) print(f"Found {jax.device_count()} devices.") print(f"Found {jax.process_count()} processes.") diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index 5e7e3447f..a0c0f558c 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -60,8 +60,7 @@ def data_load_loop(config, state=None): def main(argv: Sequence[str]) -> None: jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - pyconfig.initialize(argv) - config = pyconfig.config + config = pyconfig.initialize(argv) validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") max_logging.log(f"Found {jax.process_count()} processes.") diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index ea3ea247b..5938ac219 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -42,7 +42,7 @@ class AttentionTest(unittest.TestCase): def setUp(self): super().setUp() - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], per_device_batch_size=1.0, run_name="test", @@ -50,7 +50,7 @@ def setUp(self): max_target_length=128, max_prefill_predict_length=16, ) - self.cfg = pyconfig.config + self.cfg = config self.rng = jax.random.PRNGKey(0) devices_array = max_utils.create_device_mesh(self.cfg) @@ -336,7 +336,7 @@ def _dot_product_attention( rtol, atol = 1e-02, 1e-02 - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], per_device_batch_size=1.0, run_name="test", @@ -345,7 +345,6 @@ def _dot_product_attention( max_prefill_predict_length=16, attention="dot_product", ) - config = pyconfig.config prefill_length = config.max_prefill_predict_length decode_total_length = config.max_target_length @@ -437,7 +436,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): rtol, atol = 1e-02, 1e-02 - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], per_device_batch_size=1.0, run_name="test", @@ -446,7 +445,6 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): max_prefill_predict_length=16, attention="dot_product", ) - config = pyconfig.config prefill_length = config.max_prefill_predict_length decode_total_length = config.max_target_length @@ -727,7 +725,7 @@ class MLATest(parameterized.TestCase): def init_mla(self, rope_type): """Helper function to initialize MLA with different model names.""" - pyconfig.initialize( + cfg = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], per_device_batch_size=1.0, run_name="test", @@ -737,7 +735,6 @@ def init_mla(self, rope_type): attention_type=attentions.AttentionType.MLA.value, rope_type=rope_type, ) - cfg = pyconfig.config rng = jax.random.PRNGKey(0) devices_array = max_utils.create_device_mesh(cfg) diff --git a/MaxText/tests/forward_pass_logit_checker.py 
diff --git a/MaxText/tests/forward_pass_logit_checker.py b/MaxText/tests/forward_pass_logit_checker.py
index 8e9856e74..d41b934f1 100644
--- a/MaxText/tests/forward_pass_logit_checker.py
+++ b/MaxText/tests/forward_pass_logit_checker.py
@@ -166,6 +166,5 @@ def main(config, test_args):  # pylint: disable=W0621
   for arg in to_remove_args:
     model_args = [s for s in model_args if not s.startswith(arg)]
 
-  pyconfig.initialize(model_args)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(model_args)
   main(cfg, test_args)
diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py
index 8efeb6366..d0f87457a 100644
--- a/MaxText/tests/gpt3_test.py
+++ b/MaxText/tests/gpt3_test.py
@@ -60,15 +60,13 @@ class GPT3(unittest.TestCase):
 
   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    self.cfg = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         run_name="test",
         enable_checkpointing=False,
         model_name="gpt3-52k",
         dtype="float32",
     )
-
-    self.cfg = pyconfig.config
     self.rng = jax.random.PRNGKey(1234)
 
     devices_array = max_utils.create_device_mesh(self.cfg)
diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py
index d2034876d..54b516bca 100644
--- a/MaxText/tests/grain_data_processing_test.py
+++ b/MaxText/tests/grain_data_processing_test.py
@@ -40,7 +40,7 @@ def setUpClass(cls):
 
   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    self.config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        per_device_batch_size=1,
        run_name="test",
@@ -53,7 +53,6 @@ def setUp(self):
        tokenizer_path="../assets/tokenizer",
        enable_checkpointing=False,
    )
-    self.config = pyconfig.config
    self.mesh_shape_1d = (len(jax.devices()),)
    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
diff --git a/MaxText/tests/hf_data_processing_test.py b/MaxText/tests/hf_data_processing_test.py
index 771da1037..55a80ee46 100644
--- a/MaxText/tests/hf_data_processing_test.py
+++ b/MaxText/tests/hf_data_processing_test.py
@@ -31,7 +31,7 @@ class HfDataProcessingTest(unittest.TestCase):
 
   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        per_device_batch_size=1,
        run_name="test",
@@ -45,7 +45,7 @@ def setUp(self):
        tokenizer_path="google-t5/t5-large",
        enable_checkpointing=False,
    )
-    self.config = pyconfig.config
+    self.config = config
    self.mesh_shape_1d = (len(jax.devices()),)
    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py
index 0c3dc4fac..139003993 100644
--- a/MaxText/tests/inference_microbenchmark_smoke_test.py
+++ b/MaxText/tests/inference_microbenchmark_smoke_test.py
@@ -25,7 +25,7 @@ class Inference_Microbenchmark(unittest.TestCase):
 
   @pytest.mark.tpu_only
   def test(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [
            None,
            "configs/tpu_smoke_test.yml",
@@ -38,7 +38,7 @@ def test(self):
            "weight_dtype=bfloat16",
        ]
    )
-    run_benchmarks(pyconfig.config)
+    run_benchmarks(config)
 
 
 if __name__ == "__main__":
diff --git a/MaxText/tests/max_utils_test.py b/MaxText/tests/max_utils_test.py
index 163f7c591..df0770563 100644
--- a/MaxText/tests/max_utils_test.py
+++ b/MaxText/tests/max_utils_test.py
@@ -116,8 +116,7 @@ def __call__(self, x, y):
 
 
 class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase):
 
   def setUp(self):
-    pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
-    self.config = pyconfig.config
+    self.config = pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
     self.model = ModelWithMultipleCollections()
     self.key1, self.key2, self.key3 = random.split(random.key(0), num=3)
     self.input = random.normal(self.key1, (self.config.global_batch_size_to_load, self.config.max_target_length))
@@ -152,8 +151,7 @@ class MaxUtilsInitTransformerState(unittest.TestCase):
   """Tests initialization of transformer states in max_utils.py"""
 
   def setUp(self):
-    pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
-    self.config = pyconfig.config
+    self.config = pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
     devices_array = max_utils.create_device_mesh(self.config)
     self.mesh = Mesh(devices_array, self.config.mesh_axes)
     quant = quantizations.configure_quantization(self.config)
diff --git a/MaxText/tests/maxengine_test.py b/MaxText/tests/maxengine_test.py
index bd9cb51db..aad9cbeed 100644
--- a/MaxText/tests/maxengine_test.py
+++ b/MaxText/tests/maxengine_test.py
@@ -43,7 +43,7 @@ def setUp(self):
     self.rng = jax.random.PRNGKey(0)
 
   def init_pyconfig(self, **kwargs):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        per_device_batch_size=1.0,
        run_name="test",
@@ -57,7 +57,7 @@ def init_pyconfig(self, **kwargs):
        max_prefill_predict_length=4,
        **kwargs,
    )
-    return pyconfig.config
+    return config
 
   def get_data(self):
     s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length)
@@ -71,12 +71,11 @@ def get_data(self):
     return ids, decoder_segment_ids, decoder_positions
 
   def test_stack_and_unstack_prefill_cache(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [None, "configs/base.yml"],
        enable_checkpointing=False,
        stack_prefill_result_cache=True,
    )
-    config = pyconfig.config
    engine = MaxEngine(config, jax.devices())
    num_layers = engine.config.num_decoder_layers
    input = {
diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py
index ed1eecdb6..053ae8665 100644
--- a/MaxText/tests/model_test.py
+++ b/MaxText/tests/model_test.py
@@ -42,7 +42,7 @@ def setUp(self):
     self.rng = jax.random.PRNGKey(0)
 
   def init_pyconfig(self, **kwargs):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        per_device_batch_size=1.0,
        run_name="test",
@@ -56,7 +56,7 @@ def init_pyconfig(self, **kwargs):
        max_prefill_predict_length=4,
        **kwargs,
    )
-    return pyconfig.config
+    return config
 
   def get_data(self):
     s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length)
diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py
index d6a60afa2..e146671eb 100644
--- a/MaxText/tests/moe_test.py
+++ b/MaxText/tests/moe_test.py
@@ -38,7 +38,7 @@ class TokenDroppingTest(unittest.TestCase):
 
   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    self.cfg = pyconfig.initialize(
        [None, "configs/base.yml"],
        run_name="token_dropping_test",
        enable_checkpointing=False,
@@ -50,7 +50,6 @@ def setUp(self):
        per_device_batch_size=1,
        capacity_factor=2,
    )
-    self.cfg = pyconfig.config
    self.rng = jax.random.PRNGKey(42)
    devices_array = max_utils.create_device_mesh(self.cfg)
    self.model = linears.MoeBlock(
@@ -263,7 +262,7 @@ def get_moe_output(self, variables, hidden_states, cfg, mesh):
 
   @pytest.mark.tpu_only
   def test_megablox(self):
-    pyconfig.initialize(
+    cfg = pyconfig.initialize(
[None, "configs/base.yml"], run_name="moe_block_megablox_test", enable_checkpointing=False, @@ -274,7 +273,6 @@ def test_megablox(self): per_device_batch_size=4, ) - cfg = pyconfig.config rng = jax.random.PRNGKey(1234) rng_model, rng_hidden_states = jax.random.split(rng) hidden_states = jax.random.uniform( @@ -289,7 +287,7 @@ def test_megablox(self): @pytest.mark.tpu_only def test_dense(self): - pyconfig.initialize( + cfg = pyconfig.initialize( [None, "configs/base.yml"], run_name="moe_block_dense_test", enable_checkpointing=False, @@ -300,7 +298,6 @@ def test_dense(self): per_device_batch_size=4, ) - cfg = pyconfig.config rng = jax.random.PRNGKey(2345) rng_model, rng_hidden_states = jax.random.split(rng) hidden_states = jax.random.uniform( diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index 297d75370..78b22d812 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -35,7 +35,7 @@ class MultihostDataloadingTest(unittest.TestCase): def setUp(self): super().setUp() batch_size = 4 - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], per_device_batch_size=1, run_name="test", @@ -46,7 +46,6 @@ def setUp(self): dataset_path="gs://maxtext-dataset/", enable_checkpointing=False, ) - config = pyconfig.config global_data_shape = PartitionSpec(batch_size, config.max_target_length) data_sharding = ("data",) mesh_shape_1d = (len(jax.devices()),) diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 292aa4151..d76052632 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -160,7 +160,7 @@ def regular_sequential_layers_dummy_loss( @pytest.mark.tpu_only def test_circular_minimum_microbatches_same_output_and_grad(self): # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="circular_minimum_microbatches", @@ -171,13 +171,12 @@ def test_circular_minimum_microbatches_same_output_and_grad(self): num_pipeline_microbatches=4, per_device_batch_size=4, ) - config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu_only def test_circular_extra_microbatches_same_output_and_grad(self): # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="circular_extra_microbatches", @@ -188,13 +187,12 @@ def test_circular_extra_microbatches_same_output_and_grad(self): num_pipeline_microbatches=8, per_device_batch_size=4, ) - config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu_only def test_circular_ag_once(self): # 2 stages, 8 microbatches, all gather once - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], enable_checkpointing=False, run_name="circular_ag_once", @@ -206,13 +204,12 @@ def test_circular_ag_once(self): per_device_batch_size=4, pipeline_fsdp_ag_once=True, ) - config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) @pytest.mark.tpu_only def test_non_circular_same_output_and_grad(self): # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches - pyconfig.initialize( + config = pyconfig.initialize( [sys.argv[0], "configs/base.yml"], 
        enable_checkpointing=False,
        run_name="non_circular",
@@ -223,7 +220,6 @@ def test_non_circular_same_output_and_grad(self):
        num_pipeline_microbatches=4,
        per_device_batch_size=4,
    )
-    config = pyconfig.config
    self.assert_pipeline_same_output_and_grad(config)
 
   @pytest.mark.tpu_only
@@ -259,7 +255,7 @@ def test_full_train_circular(self):
   @pytest.mark.tpu_only
   def test_delay_activation_forwarding_same_output_and_grad(self):
     # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        enable_checkpointing=False,
        run_name="activation_forwarding",
@@ -271,7 +267,6 @@ def test_delay_activation_forwarding_same_output_and_grad(self):
        per_device_batch_size=4,
        pipeline_delay_activation_forwarding=True,
    )
-    config = pyconfig.config
    self.assert_pipeline_same_output_and_grad(config)
 
   @pytest.mark.tpu_only
diff --git a/MaxText/tests/profiler_test.py b/MaxText/tests/profiler_test.py
index 6eb0568f3..70af6c2e3 100644
--- a/MaxText/tests/profiler_test.py
+++ b/MaxText/tests/profiler_test.py
@@ -29,7 +29,7 @@ class ProfilerTest(unittest.TestCase):
   # These periodic profiler tests can run on any platform (cpu, gpu or tpu)
   @pytest.mark.tpu_only
   def test_periodic_profiler_third_period_starts(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        enable_checkpointing=False,
        run_name="test_periodic_profiler_starts_after_regular_profile",
@@ -38,7 +38,6 @@ def test_periodic_profiler_third_period_starts(self):
        profiler_steps=4,
        profile_periodically_period=5,
    )
-    config = pyconfig.config
 
    prof = profiler.Profiler(config, offset_step=2)
    step = 24  # 3 * 5 + 7 + 2: 3 periods of 5 after skipping initial 7 skip + 2 offset.
@@ -46,7 +45,7 @@ def test_periodic_profiler_third_period_starts(self):
 
   @pytest.mark.tpu_only
   def test_periodic_profiler_not_start_middle_period(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        enable_checkpointing=False,
        run_name="test_periodic_profiler_starts_after_regular_profile",
@@ -55,7 +54,6 @@ def test_periodic_profiler_not_start_middle_period(self):
        profiler_steps=4,
        profile_periodically_period=5,
    )
-    config = pyconfig.config
 
    prof = profiler.Profiler(config, offset_step=2)
    step = 25  # This corresponds to the middle of period 3 which started at step 24.
@@ -63,7 +61,7 @@ def test_periodic_profiler_not_start_middle_period(self):
 
   @pytest.mark.tpu_only
   def test_periodic_profiler_third_period_ends(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        enable_checkpointing=False,
        run_name="test_periodic_profiler_starts_after_regular_profile",
@@ -72,7 +70,6 @@ def test_periodic_profiler_third_period_ends(self):
        profiler_steps=4,
        profile_periodically_period=5,
    )
-    config = pyconfig.config
 
    prof = profiler.Profiler(config, offset_step=2)
    step = 27  # 3 * 5 + 4 + 7 + 2: 3 periods of 5, profile takes 4 steps + skipping initial 7 skip + 2 offset
@@ -80,7 +77,7 @@ def test_periodic_profiler_third_period_ends(self):
 
   @pytest.mark.tpu_only
   def test_periodic_profiler_third_period_middle_not_end(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        enable_checkpointing=False,
        run_name="test_periodic_profiler_starts_after_regular_profile",
@@ -89,7 +86,6 @@ def test_periodic_profiler_third_period_middle_not_end(self):
        profiler_steps=4,
        profile_periodically_period=5,
    )
-    config = pyconfig.config
 
    prof = profiler.Profiler(config, offset_step=2)
    step = 28  # Corresponds to 1 after the third period ended.
diff --git a/MaxText/tests/pyconfig_test.py b/MaxText/tests/pyconfig_test.py
index bae6600dd..a147b5265 100644
--- a/MaxText/tests/pyconfig_test.py
+++ b/MaxText/tests/pyconfig_test.py
@@ -63,3 +63,45 @@ def test_logical_axis_partial_override(self):
            "logical_axis_rules": [("activation", ("data", "fsdp")), ("norm", "fsdp")],
        },
    )
+
+  def test_multiple_unmodifiable_configs(self):
+    config_train = pyconfig.initialize(
+        ["train.py", "configs/base.yml"],
+        per_device_batch_size=1.0,
+        run_name="test",
+        enable_checkpointing=False,
+        base_num_decoder_layers=2,
+        attention="dot_product",
+        max_target_length=16,
+        base_emb_dim=256,
+        base_num_query_heads=2,
+        base_num_kv_heads=2,
+        max_prefill_predict_length=4,
+        ici_tensor_parallelism=-1,
+        ici_fsdp_parallelism=4,
+    )
+    config_inference = pyconfig.initialize(
+        ["decode.py", "configs/base.yml"],
+        per_device_batch_size=1.0,
+        run_name="test",
+        enable_checkpointing=False,
+        base_num_decoder_layers=2,
+        attention="dot_product",
+        max_target_length=16,
+        base_emb_dim=256,
+        base_num_query_heads=2,
+        base_num_kv_heads=2,
+        max_prefill_predict_length=4,
+        ici_tensor_parallelism=4,
+        ici_fsdp_parallelism=-1,
+    )
+    self.assertNotEqual(
+        config_train.ici_tensor_parallelism,
+        config_inference.ici_tensor_parallelism,
+    )
+    with self.assertRaises(ValueError):
+      config_inference.__setattr__("ici_fsdp_parallelism", 4)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py
index c399855bd..1d24e420a 100644
--- a/MaxText/tests/quantizations_test.py
+++ b/MaxText/tests/quantizations_test.py
@@ -50,14 +50,13 @@ def __call__(self, inputs):
 
 
 def _configure_quantization(quant_str="", quant_cfg_path="", mode_str="train", replicate_scale=False):
-  pyconfig.initialize(
+  config = pyconfig.initialize(
      [None, "configs/base.yml"],
      enable_checkpointing=False,
      quantization=quant_str,
      quant_cfg_path=quant_cfg_path,
      replicate_quant_scale=replicate_scale,
  )
-  config = pyconfig.config
  quant = quantizations.configure_quantization(config, mode_str)
  return quant
diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py
index a89d99dd2..998bacee3 100644
--- a/MaxText/tests/tfds_data_processing_test.py
+++ b/MaxText/tests/tfds_data_processing_test.py
@@ -34,7 +34,7 @@ class TfdsDataProcessingTest(unittest.TestCase):
 
   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    config = pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        per_device_batch_size=1,
        run_name="test",
@@ -47,8 +47,8 @@ def setUp(self):
        enable_checkpointing=False,
        eval_interval=10,
    )
-    os.environ["TFDS_DATA_DIR"] = pyconfig.config.dataset_path
-    self.config = pyconfig.config
+    os.environ["TFDS_DATA_DIR"] = config.dataset_path
+    self.config = config
    self.mesh_shape_1d = (len(jax.devices()),)
    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
diff --git a/MaxText/tests/weight_dtypes_test.py b/MaxText/tests/weight_dtypes_test.py
index 579e23310..ac4525e7c 100644
--- a/MaxText/tests/weight_dtypes_test.py
+++ b/MaxText/tests/weight_dtypes_test.py
@@ -38,8 +38,7 @@ def get_weights(self, argv):
     """Gets model weights"""
 
     # Setup necessary inputs to build a model state
-    pyconfig.initialize(argv)
-    config = pyconfig.config
+    config = pyconfig.initialize(argv)
     quant = quantizations.configure_quantization(config)
     devices_array = max_utils.create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
diff --git a/MaxText/train.py b/MaxText/train.py
index d5afc8d61..76d061eef 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -27,7 +27,7 @@
 import time
 import queue
 
-from typing import Sequence, Optional
+from typing import Sequence
 from absl import app
 from flax import linen as nn
 from flax.linen import partitioning as nn_partitioning
@@ -199,7 +199,7 @@ def save_checkpoint(
     state,
     dataset_type="c4",
     data_iterator=None,
-    config: Optional[pyconfig.config] = None,
+    config=None,
 ) -> bool:
   """Wrapper for saving checkpoint."""
   if config and config.enable_checkpointing:
@@ -1000,8 +1000,7 @@ def main(argv: Sequence[str]) -> None:
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
     os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
-  pyconfig.initialize(argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(argv)
   max_utils.print_system_information()
   validate_train_config(config)
   os.environ["TFDS_DATA_DIR"] = config.dataset_path
diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py
index 08a33462f..ee9c3d627 100644
--- a/MaxText/train_compile.py
+++ b/MaxText/train_compile.py
@@ -140,8 +140,7 @@ def main(argv: Sequence[str]) -> None:
   print("Starting train_compile.py...", flush=True)
 
   # Parse and validate configuration
-  pyconfig.initialize(argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(argv)
   validate_config(config)
 
   # Create target mesh
diff --git a/benchmarks/mmlu/mmlu_eval.py b/benchmarks/mmlu/mmlu_eval.py
index 399ccd0e0..257608289 100644
--- a/benchmarks/mmlu/mmlu_eval.py
+++ b/benchmarks/mmlu/mmlu_eval.py
@@ -217,8 +217,7 @@ def validate_config(config):
 if __name__ == "__main__":
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   flags.FLAGS(sys.argv)
-  pyconfig.initialize(sys.argv)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(sys.argv)
   validate_config(cfg)
   max_utils.print_system_information()
   main(cfg)
diff --git a/constraints_gpu.txt b/constraints_gpu.txt
index 9327f0645..fe02195e5 100644
--- a/constraints_gpu.txt
+++ b/constraints_gpu.txt
@@ -208,4 +208,5 @@ Werkzeug==2.0.3
 wrapt==1.16.0
 xxhash==3.5.0
 yarl==1.16.0
-zipp==3.20.2
\ No newline at end of file
+zipp==3.20.2
+omegaconf
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index aeea52445..9bfb9efa3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,3 +37,4 @@ mlperf-logging@git+https://github.com/mlperf/logging.git
 google-jetstream
 jsonlines
 pathwaysutils@git+https://github.com/google/pathways-utils.git
+omegaconf
\ No newline at end of file
diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt
index 819fa00f9..329b60568 100644
--- a/requirements_with_jax_stable_stack.txt
+++ b/requirements_with_jax_stable_stack.txt
@@ -22,4 +22,5 @@ jsonlines
 pathwaysutils@git+https://github.com/google/pathways-utils.git
 google-cloud-monitoring
 google-api-core
-google-api-python-client
\ No newline at end of file
+google-api-python-client
+omegaconf
\ No newline at end of file
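
Usage note: a minimal sketch of the calling convention this patch introduces, based on the restructured pyconfig.py above. The argv shape and keyword overrides mirror the test hunks; run_name="demo" is an illustrative value.

    import pyconfig

    # initialize() now returns a read-only HyperParameters wrapper instead of
    # populating the module-level pyconfig.config global, so several configs
    # can coexist in one process.
    config = pyconfig.initialize(
        [None, "configs/base.yml"],  # argv-style: program name, then the YAML config path
        run_name="demo",             # keyword overrides merge on top of YAML/CLI/env values
        enable_checkpointing=False,
    )
    print(config.steps)  # attribute access reads the merged config keys
    config.steps = 10    # raises ValueError: the wrapper rejects mutation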