Merge branch 'main' into no_tile_emb
RdoubleA committed Jan 13, 2025
2 parents 33278b1 + e79ab8b commit 8dd77e5
Showing 12 changed files with 31 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/llama3.rst
@@ -230,7 +230,7 @@ Running generation with our LoRA-finetuned model, we see the following output:
 .. code-block:: bash
     tune run generate --config ./custom_generation_config.yaml \
-    prompt="Hello, my name is"
+    prompt.user="Hello, my name is"
 [generate.py:122] Hello, my name is Sarah and I am a busy working mum of two young children, living in the North East of England.
 ...
2 changes: 1 addition & 1 deletion recipes/dev/early_exit_finetune_distributed.py
@@ -951,7 +951,7 @@ def train(self) -> None:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
-                    )
+                    ).full_tensor()
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)

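
The ).full_tensor() change here (and repeated verbatim in the five distributed recipes below) accounts for FSDP2 sharding: with DTensor parameters, clip_grad_norm_ returns the total gradient norm as a DTensor, and .full_tensor() gathers it into a regular tensor before it is logged. A minimal sketch of that behavior, not the recipe's exact code (the helper name, the import path, and the isinstance guard are illustrative assumptions):

    import torch
    from torch.distributed.tensor import DTensor  # public path in recent PyTorch; older versions expose torch.distributed._tensor

    def clip_and_report(model: torch.nn.Module, max_norm: float) -> torch.Tensor:
        # With sharded (DTensor) parameters, clip_grad_norm_ returns the total
        # gradient norm as a DTensor; .full_tensor() all-gathers it into a plain
        # tensor so it can be logged or .item()-ed identically on every rank.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        if isinstance(grad_norm, DTensor):
            grad_norm = grad_norm.full_tensor()
        return grad_norm
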
2 changes: 1 addition & 1 deletion recipes/full_finetune_distributed.py
@@ -786,7 +786,7 @@ def train(self) -> None:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
-                    )
+                    ).full_tensor()
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)

2 changes: 1 addition & 1 deletion recipes/lora_finetune_distributed.py
@@ -828,7 +828,7 @@ def train(self) -> None:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
-                    )
+                    ).full_tensor()
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)
                 self._lr_scheduler.step()
2 changes: 1 addition & 1 deletion recipes/lora_finetune_distributed_multi_dataset.py
@@ -857,7 +857,7 @@ def train(self) -> None:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
-                    )
+                    ).full_tensor()
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)
                 self._lr_scheduler.step()
2 changes: 1 addition & 1 deletion recipes/qat_distributed.py
@@ -857,7 +857,7 @@ def train(self) -> None:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
-                    )
+                    ).full_tensor()
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)

2 changes: 1 addition & 1 deletion recipes/qat_lora_finetune_distributed.py
@@ -872,7 +872,7 @@ def train(self) -> None:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
-                    )
+                    ).full_tensor()
                 self._optimizer.step()
                 self._optimizer.zero_grad(set_to_none=True)
                 self._lr_scheduler.step()
12 changes: 9 additions & 3 deletions tests/torchtune/config/test_config_utils.py
@@ -28,6 +28,8 @@
     },
     "d": 4,
     "f": 8,
+    "g": "foo",
+    "h": "${g}/bar",
 }


@@ -50,7 +52,9 @@ def test_get_component_from_path(self):
         ):
             _ = _get_component_from_path("torchtune.models.dummy")

-    @mock.patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
+    @mock.patch(
+        "torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
+    )
     def test_merge_yaml_and_cli_args(self, mock_load):
         parser = TuneRecipeArgumentParser("test parser")
         yaml_args, cli_args = parser.parse_known_args(
@@ -63,6 +67,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
                 "d=6", # Test overriding a flat param
                 "e=7", # Test adding a new param
                 "~f", # Test removing a param
+                "g=bazz", # Test interpolation happens after override
             ]
         )
         conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
@@ -75,6 +80,7 @@ def test_merge_yaml_and_cli_args(self, mock_load):
         assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides."
         assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides."
         assert "f" not in conf, f"f == {conf.f}, not removed as set in overrides."
+        assert conf.h == "bazz/bar", f"h == {conf.h}, not bazz/bar as set in overrides."
         mock_load.assert_called_once()

         yaml_args, cli_args = parser.parse_known_args(
@@ -185,5 +191,5 @@ def test_remove_key_by_dotpath(self):

         # Test removing non-existent param fails
         cfg = copy.deepcopy(_CONFIG)
-        with pytest.raises(KeyError, match="'g'"):
-            _remove_key_by_dotpath(cfg, "g")
+        with pytest.raises(KeyError, match="'i'"):
+            _remove_key_by_dotpath(cfg, "i")
14 changes: 12 additions & 2 deletions tests/torchtune/config/test_parse.py
@@ -13,7 +13,7 @@

 from torchtune.config._parse import TuneRecipeArgumentParser

-_CONFIG = {"a": 1, "b": 2}
+_CONFIG = {"a": 1, "b": 2, "c": "foo", "d": "${c}/bar"}


 class TestParse:
@@ -41,7 +41,9 @@ def parser(self):
         parser = TuneRecipeArgumentParser("Test parser")
         return parser

-    @patch("torchtune.config._parse.OmegaConf.load", return_value=_CONFIG)
+    @patch(
+        "torchtune.config._parse.OmegaConf.load", return_value=OmegaConf.create(_CONFIG)
+    )
     def test_parse_known_args(self, mock_load, parser):
         """
         Test that the parser can load a config and override parameters provided on CLI.
@@ -65,3 +67,11 @@ def test_parse_known_args(self, mock_load, parser):
             _ = parser.parse_known_args(
                 ["--config", "test.yaml", "--b", "3"],
             )
+
+        # Test that parsing does not prematurely interpolate variables.
+        config_args, cli_args = parser.parse_known_args(
+            ["--config", "test.yaml", "c=bazz"]
+        )
+        assert (
+            config_args.d == "${c}/bar"
+        ), f"d == {config_args.d} not ${{c}}/bar as set in config."
@@ -161,7 +161,6 @@ def test_tiled_token_positional_embedding_aoti(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             path = torch._inductor.aoti_compile_and_package(
                 tpe_ep,
-                (self.x, self.aspect_ratio),
                 package_path=os.path.join(tmpdir, "tpe.pt2"),
             )
             tpe_aoti = load_package(path)
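
Dropping the (self.x, self.aspect_ratio) tuple matches the newer AOTInductor packaging API, where aoti_compile_and_package takes only the exported program because example inputs are already recorded at torch.export time. A rough sketch of that flow under that assumption; the toy module and the torch._inductor.aoti_load_package loader are illustrative stand-ins for the test's actual setup:

    import os
    import tempfile

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.relu(x) * 2

    model = Toy()
    example_inputs = (torch.randn(4, 8),)

    # The example inputs are captured inside the ExportedProgram...
    ep = torch.export.export(model, example_inputs)

    with tempfile.TemporaryDirectory() as tmpdir:
        # ...so the packaging call no longer needs them passed a second time.
        path = torch._inductor.aoti_compile_and_package(
            ep, package_path=os.path.join(tmpdir, "toy.pt2")
        )
        compiled = torch._inductor.aoti_load_package(path)
        print(compiled(torch.randn(4, 8)).shape)
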
2 changes: 1 addition & 1 deletion torchtune/config/_parse.py
@@ -57,7 +57,7 @@ def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]:

         config = OmegaConf.load(namespace.config)
         assert "config" not in config, "Cannot use 'config' within a config file"
-        self.set_defaults(**config)
+        self.set_defaults(**OmegaConf.to_container(config, resolve=False))

         namespace, unknown_args = super().parse_known_args(*args, **kwargs)
         del namespace.config
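
Passing resolve=False is the heart of this change: the YAML values are handed to argparse as raw strings, so an interpolation such as ${c}/bar is only resolved after CLI overrides have been merged, which is exactly what the new test assertions above check. A minimal sketch of the difference, using OmegaConf directly rather than torchtune's parser:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"c": "foo", "d": "${c}/bar"})

    # Resolving too early bakes in the original value of c:
    print(OmegaConf.to_container(cfg, resolve=True))   # {'c': 'foo', 'd': 'foo/bar'}

    # Keeping the raw string lets a later override of c flow through to d:
    raw = OmegaConf.to_container(cfg, resolve=False)   # {'c': 'foo', 'd': '${c}/bar'}
    merged = OmegaConf.merge(OmegaConf.create(raw), OmegaConf.from_dotlist(["c=bazz"]))
    print(merged.d)  # bazz/bar
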
4 changes: 2 additions & 2 deletions torchtune/models/llama3_1/_model_builders.py
@@ -73,7 +73,7 @@ def llama3_1_405b() -> TransformerDecoder:
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=131072,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -236,7 +236,7 @@ def lora_llama3_1_405b(
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=131072,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
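
The jump from 8192 to 131072 brings the 405B builders in line with Llama 3.1's 128K-token context window; the new value is simply 128 * 1024. As a quick, assumption-light check:

    # 131072 tokens = 128 * 1024, i.e. the 128K context length advertised for Llama 3.1.
    assert 128 * 1024 == 131072
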
