From 262122b00d753c3ea1ad14dde956c601b60cf9ee Mon Sep 17 00:00:00 2001
From: akashc1 <43617927+akashc1@users.noreply.github.com>
Date: Fri, 10 Jan 2025 12:02:46 -0800
Subject: [PATCH 1/5] llama 3.1 has correct `max_seq_len` for all versions (#2203)

---
 torchtune/models/llama3_1/_model_builders.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtune/models/llama3_1/_model_builders.py b/torchtune/models/llama3_1/_model_builders.py
index b6439b2eb2..f48ce580f5 100644
--- a/torchtune/models/llama3_1/_model_builders.py
+++ b/torchtune/models/llama3_1/_model_builders.py
@@ -73,7 +73,7 @@ def llama3_1_405b() -> TransformerDecoder:
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=131072,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -236,7 +236,7 @@ def lora_llama3_1_405b(
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=131072,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,

From f47f6335a472b0fda436fd9c46ae6b04904cbe49 Mon Sep 17 00:00:00 2001
From: ebsmothers
Date: Fri, 10 Jan 2025 12:39:15 -0800
Subject: [PATCH 2/5] Log grad norm aggregated over all ranks, not just rank zero (#2248)

---
 recipes/dev/early_exit_finetune_distributed.py     | 2 +-
 recipes/full_finetune_distributed.py               | 2 +-
 recipes/lora_finetune_distributed.py               | 2 +-
 recipes/lora_finetune_distributed_multi_dataset.py | 2 +-
 recipes/qat_distributed.py                         | 2 +-
 recipes/qat_lora_finetune_distributed.py           | 2 +-
 6 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/recipes/dev/early_exit_finetune_distributed.py b/recipes/dev/early_exit_finetune_distributed.py
index 5abc674356..663697e978 100644
--- a/recipes/dev/early_exit_finetune_distributed.py
+++ b/recipes/dev/early_exit_finetune_distributed.py
@@ -951,7 +951,7 @@ def train(self) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        )
+                        ).full_tensor()
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 9ef5e6533f..4f32faefdb 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -786,7 +786,7 @@ def train(self) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        )
+                        ).full_tensor()
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 2cdfcd8010..39c8b104e5 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -828,7 +828,7 @@ def train(self) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        )
+                        ).full_tensor()
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
                     self._lr_scheduler.step()

diff --git a/recipes/lora_finetune_distributed_multi_dataset.py b/recipes/lora_finetune_distributed_multi_dataset.py
index ce482bfa27..a50147df8a 100644
--- a/recipes/lora_finetune_distributed_multi_dataset.py
+++ b/recipes/lora_finetune_distributed_multi_dataset.py
@@ -857,7 +857,7 @@ def train(self) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        )
+                        ).full_tensor()
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
                     self._lr_scheduler.step()

diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index f1b1302b7d..8c458daa21 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -857,7 +857,7 @@ def train(self) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        )
+                        ).full_tensor()
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py
index 133c39c94b..c742dae226 100644
--- a/recipes/qat_lora_finetune_distributed.py
+++ b/recipes/qat_lora_finetune_distributed.py
@@ -872,7 +872,7 @@ def train(self) -> None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
                             max_norm=float(self._clip_grad_norm),
-                        )
+                        ).full_tensor()
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
                     self._lr_scheduler.step()
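
Note on [PATCH 2/5]: with FSDP2-style sharding the gradients are DTensors, so the value returned by `torch.nn.utils.clip_grad_norm_` is itself a DTensor, and logging it directly on rank zero reports a value that does not account for the shards held on other ranks. The added `.full_tensor()` call performs the collective that materializes the norm aggregated over all ranks. A minimal sketch of the pattern outside the recipes; the `model` object and `max_norm` value are placeholders, not taken from the recipes:

    import torch
    from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older releases

    # `model` is assumed to be wrapped with FSDP2 (fully_shard), so its gradients
    # are DTensors and clip_grad_norm_ returns a DTensor norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if isinstance(grad_norm, DTensor):
        # Aggregate across ranks before logging; the raw DTensor's local value
        # only reflects the locally held shards.
        grad_norm = grad_norm.full_tensor()
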
From c1522489cbcbcaf1f06a350c8ed1b0739d19bab3 Mon Sep 17 00:00:00 2001
From: Angela Yi
Date: Fri, 10 Jan 2025 14:43:19 -0800
Subject: [PATCH 3/5] Remove example inputs from aoti_compile_and_package

Differential Revision: D67998952

Pull Request resolved: https://github.com/pytorch/torchtune/pull/2244
---
 .../torchtune/modules/_export/test_export_position_embeddings.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/torchtune/modules/_export/test_export_position_embeddings.py b/tests/torchtune/modules/_export/test_export_position_embeddings.py
index 6907ca3edd..3beb23e7ef 100644
--- a/tests/torchtune/modules/_export/test_export_position_embeddings.py
+++ b/tests/torchtune/modules/_export/test_export_position_embeddings.py
@@ -161,7 +161,6 @@ def test_tiled_token_positional_embedding_aoti(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             path = torch._inductor.aoti_compile_and_package(
                 tpe_ep,
-                (self.x, self.aspect_ratio),
                 package_path=os.path.join(tmpdir, "tpe.pt2"),
             )
             tpe_aoti = load_package(path)
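
Note on [PATCH 3/5]: the removed tuple was the example inputs. In the PyTorch version this test targets, `aoti_compile_and_package` takes only the exported program plus keyword options, since the example inputs are already recorded on the `ExportedProgram` produced by `torch.export.export`, so passing them again is redundant. A rough sketch of the calling pattern, with `model`, `x`, and `aspect_ratio` as placeholders for the objects the test constructs:

    import os
    import tempfile

    import torch

    # Export records the example inputs on the ExportedProgram itself.
    ep = torch.export.export(model, (x, aspect_ratio))

    # Compile and package for AOTInductor; no example inputs are passed here.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = torch._inductor.aoti_compile_and_package(
            ep,
            package_path=os.path.join(tmpdir, "model.pt2"),
        )
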
From dadba25f860d74ac6a3b8541d28d96d6005379b1 Mon Sep 17 00:00:00 2001
From: Insop <1240382+insop@users.noreply.github.com>
Date: Fri, 10 Jan 2025 20:38:18 -0800
Subject: [PATCH 4/5] Fix issue #2243, update the document to show correct usage (#2252)

---
 docs/source/tutorials/llama3.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/tutorials/llama3.rst b/docs/source/tutorials/llama3.rst
index 938eafae27..6ceac07e8f 100644
--- a/docs/source/tutorials/llama3.rst
+++ b/docs/source/tutorials/llama3.rst
@@ -230,7 +230,7 @@ Running generation with our LoRA-finetuned model, we see the following output:
 .. code-block:: bash
 
     tune run generate --config ./custom_generation_config.yaml \
-    prompt="Hello, my name is"
+    prompt.user="Hello, my name is"
 
     [generate.py:122] Hello, my name is Sarah and I am a busy working mum of two young children, living in the North East of England.
     ...
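
Note on [PATCH 4/5]: the generation config nests the prompt, so the CLI override must use a dotted path (`prompt.user=...`) rather than `prompt=...`. torchtune merges CLI overrides as OmegaConf dotted keys; a small standalone illustration of why the dotted form is needed (the `system`/`user` layout here is illustrative, not copied from the real generation config):

    from omegaconf import OmegaConf

    base = OmegaConf.create({"prompt": {"system": None, "user": "placeholder"}})

    # A dotted key targets only the nested field, leaving the rest of the
    # `prompt` node intact.
    override = OmegaConf.from_dotlist(['prompt.user=Hello, my name is'])
    merged = OmegaConf.merge(base, override)
    print(merged.prompt.user)  # Hello, my name is
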
From e79ab8bd03c3cd4a0a1cf198de9c42871ecf3629 Mon Sep 17 00:00:00 2001
From: Eugen Hotaj
Date: Sun, 12 Jan 2025 11:15:44 -0800
Subject: [PATCH 5/5] [EZ] Fix config bug where interpolation happens too early (#2236)

---
 tests/torchtune/config/test_config_utils.py | 12 +++++++++---
 tests/torchtune/config/test_parse.py        | 14 ++++++++++++--
 torchtune/config/_parse.py                  |  2 +-
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/tests/torchtune/config/test_config_utils.py b/tests/torchtune/config/test_config_utils.py
index bfce087dbf..b3a2baf063 100644
--- a/tests/torchtune/config/test_config_utils.py
+++ b/tests/torchtune/config/test_config_utils.py
@@ -28,6 +28,8 @@
     },
     "d": 4,
     "f": 8,
+    "g": "foo",
+    "h": "${g}/bar",
 }
 
 
@@ -50,7 +52,9 @@ def test_get_component_from_path(self):
         ):
             _ = _get_component_from_path("torchtune.models.dummy")
 
-    @mock.patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
+    @mock.patch(
+        "torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
+    )
     def test_merge_yaml_and_cli_args(self, mock_load):
         parser = TuneRecipeArgumentParser("test parser")
         yaml_args, cli_args = parser.parse_known_args(
@@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
                 "d=6",  # Test overriding a flat param
                 "e=7",  # Test adding a new param
                 "~f",  # Test removing a param
+                "g=bazz",  # Test interpolation happens after override
             ]
         )
         conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
@@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
         assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides."
         assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides."
         assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides."
+        assert conf.h == "bazz/bar", f"h == {conf.h}, not bazz/bar as set in overrides."
         mock_load.assert_called_once()
 
         yaml_args, cli_args = parser.parse_known_args(
@@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):
 
         # Test removing non-existent param fails
         cfg = copy.deepcopy(_CONFIG)
-        with pytest.raises(KeyError, match="'g'"):
-            _remove_key_by_dotpath(cfg, "g")
+        with pytest.raises(KeyError, match="'i'"):
+            _remove_key_by_dotpath(cfg, "i")

diff --git a/tests/torchtune/config/test_parse.py b/tests/torchtune/config/test_parse.py
index c4e278acaf..e396b10864 100644
--- a/tests/torchtune/config/test_parse.py
+++ b/tests/torchtune/config/test_parse.py
@@ -13,7 +13,7 @@
 
 from torchtune.config._parse import TuneRecipeArgumentParser
 
-_CONFIG = {"a": 1, "b": 2}
+_CONFIG = {"a": 1, "b": 2, "c": "foo", "d": "${c}/bar"}
 
 
 class TestParse:
@@ -41,7 +41,9 @@ def parser(self):
         parser = TuneRecipeArgumentParser("Test parser")
         return parser
 
-    @patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
+    @patch(
+        "torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
+    )
     def test_parse_known_args(self, mock_load, parser):
         """
         Test that the parser can load a config and override parameters provided on CLI.
@@ -65,3 +67,11 @@ def test_parse_known_args(self, mock_load, parser):
             _ = parser.parse_known_args(
                 ["--config", "test.yaml", "--b", "3"],
             )
+
+        # Test that parsing does not prematurely interpolate variables.
+        config_args, cli_args = parser.parse_known_args(
+            ["--config", "test.yaml", "c=bazz"]
+        )
+        assert (
+            config_args.d == "${c}/bar"
+        ), f"d == {config_args.d} not ${{c}}/bar as set in config."

diff --git a/torchtune/config/_parse.py b/torchtune/config/_parse.py
index 5a8e762333..0a29d3be22 100644
--- a/torchtune/config/_parse.py
+++ b/torchtune/config/_parse.py
@@ -57,7 +57,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]:
         config = OmegaConf.load(namespace.config)
         assert "config" not in config, "Cannot use 'config' within a config file"
-        self.set_defaults(**config)
+        self.set_defaults(**OmegaConf.to_container(config, resolve=False))
 
         namespace, unknown_args = super().parse_known_args(*args, **kwargs)
         del namespace.config
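
Note on [PATCH 5/5]: the one-line change in `torchtune/config/_parse.py` is the actual fix. Passing the loaded config through `OmegaConf.to_container(config, resolve=False)` keeps interpolation strings such as `${c}/bar` unresolved when they are registered as argparse defaults, so they only resolve after CLI overrides have been merged. A minimal standalone illustration of the difference; the keys mirror the test fixtures rather than a real recipe config:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"c": "foo", "d": "${c}/bar"})

    # Resolving at load time bakes in the pre-override value of `c`:
    print(OmegaConf.to_container(cfg, resolve=True))   # {'c': 'foo', 'd': 'foo/bar'}

    # Keeping the interpolation string lets a later override of `c` flow into `d`:
    print(OmegaConf.to_container(cfg, resolve=False))  # {'c': 'foo', 'd': '${c}/bar'}

    merged = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["c=bazz"]))
    print(merged.d)  # bazz/bar
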