
Update model builders #2282

Open · wants to merge 10 commits into main
Conversation

Ankur-singh (Contributor)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (Refactor)

Please link to any issues this PR addresses #2270

Changelog

What are the changes made in this PR?

  • Updated all component builders to pass an nn.ModuleList to TransformerDecoder instead of a single layer plus num_layers.
  • Updated tests for the T5 model
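The builder change above can be sketched as follows. This is a minimal illustration with a generic make_layer factory, not the actual torchtune builder code: previously a single layer was constructed and cloned num_layers times inside TransformerDecoder, whereas now each layer is constructed independently and passed as an nn.ModuleList.

```python
import torch
import torch.nn as nn

def make_layer() -> nn.Module:
    # Stand-in for a real transformer-layer builder.
    return nn.Linear(16, 16, bias=False)

# Old pattern: one layer plus a count; the decoder cloned the layer
# internally, so every clone started from identical weights.
def build_old(num_layers: int):
    return make_layer(), num_layers

# New pattern: each layer is constructed (and randomly initialized)
# independently, then passed as an nn.ModuleList.
def build_new(num_layers: int) -> nn.ModuleList:
    return nn.ModuleList([make_layer() for _ in range(num_layers)])

layers = build_new(4)
```

Note that with the new pattern each layer starts from its own random initialization, which is why expected values in tests change later in this thread.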

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Jan 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2282

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 19, 2025
@Ankur-singh (Contributor Author)

@ebsmothers For now, I have only updated the component_builders for all the models. In my local testing, removing this else block passes all the tests. If this PR looks good, I will:

  1. Update VisionTransformer class and component builders for vision models
  2. Delete the _get_clones function

@RdoubleA (Contributor)

Just one comment about making sure RoPE is instantiated correctly, but the rest looks good.

Comment on lines 48 to 50
self.layers = (
    layers if isinstance(layers, nn.ModuleList) else nn.ModuleList(layers)
)
Contributor

In this case it's not a big deal since T5 is not yet widely used in the repo, but FYI we do have to be careful about deprecating num_layers as it can break people. The proper thing to do is continue supporting it for one release and mark it as deprecated (similar to what we have here for functions/classes).

Contributor Author

Thanks for pointing it out. I will re-introduce support for num_layers and add a deprecated decorator with the message "The num_layers argument will be deprecated in an upcoming release."

Contributor Author

@ebsmothers Do you want me to add the deprecated decorator to the TransformerDecoder class as well?

Contributor

@Ankur-singh actually I wouldn't worry about adding the decorator to the class. I think we need a separate utility to log only when the to-be-deprecated argument is passed. Otherwise everyone will see the warning about deprecation of num_layers even if they aren't using it.
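One possible shape for such a utility (a hypothetical sketch, not an existing torchtune helper): a decorator that inspects the call and warns only when the deprecated keyword argument is actually supplied.

```python
import functools
import warnings

def warn_if_passed(*deprecated_names):
    """Hypothetical helper: emit a deprecation warning only when one of
    the named keyword arguments is explicitly passed by the caller."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for name in deprecated_names:
                if kwargs.get(name) is not None:
                    warnings.warn(
                        f"The '{name}' argument will be deprecated in an "
                        "upcoming release.",
                        FutureWarning,
                    )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@warn_if_passed("num_layers")
def build_decoder(layers=None, num_layers=None):
    # Stand-in for the real builder; only the warning behavior matters here.
    return layers if layers is not None else ["layer"] * num_layers
```

Callers who pass only the new layers argument never see the warning; only those who still pass num_layers do.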

[0.3383, 0.3150],
[0.3727, 0.2892],
[0.3996, 0.2653],
[0.4958, 0.4845],
Contributor

I'm a bit surprised that the expected values need to change here. I would have thought that uniform initialization iterating over model.parameters() shouldn't be affected by whether we use _get_clones or not

Contributor Author

After making all the changes, I ran pytest tests:

  • with _get_clones: success
  • without _get_clones: failed

Hence, I thought the difference was because of _get_clones. I ran the following script to test the hypothesis:

import torch
from torchtune.modules.transformer import _get_clones
from torchtune.modules.peft import LoRALinear

def main():
    loras_loop = [None] * 4
    for i in range(4):
        loras_loop[i] = LoRALinear(in_dim=16, out_dim=16, rank=4, alpha=1.0)

    loras_cloned = _get_clones(
        LoRALinear(in_dim=16, out_dim=16, rank=4, alpha=1.0), 4
    )

    loop_max_diff = torch.max(torch.abs(loras_loop[0].lora_a.weight - loras_loop[3].lora_a.weight))
    cloned_max_diff = torch.max(torch.abs(loras_cloned[0].lora_a.weight - loras_cloned[3].lora_a.weight))

    print(f"Max diff between layers using for-loop: {loop_max_diff}")
    print(f"Max diff between layers using _get_clones: {cloned_max_diff}")

    input = torch.randn(1, 16)
    
    output1 = input.clone()
    for layer in loras_loop:
        output1 = layer(output1)

    output2 = input.clone()
    for layer in loras_cloned:
        output2 = layer(output2)

    cloned_max_diff = torch.max(torch.abs(output1 - output2))
    print(f"Max diff between outputs from two approach: {cloned_max_diff}")


if __name__ == "__main__":
    main()

# ----
# Max diff between layers using for-loop: 0.4660979211330414
# Max diff between layers using _get_clones: 0.0
# Max diff between outputs from two approach: 0.3515825569629669

Contributor

Yeah I suspect this is due to something with the random seed. Let me take a closer look to confirm but otherwise I think updating the expected values is fine

Contributor

Also dug into this one a bit more: the changes to T5 change the order in which modules are registered, so when we do

for param in model.parameters():
    param.data.uniform_(0, 1)

we again wind up with different state for the rng. So again it's fine to change the expected values here
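The effect described above can be reproduced with a toy example (hypothetical modules, not the T5 code): two models holding identical submodules registered in different orders consume the RNG stream in different orders, so the "same" parameter ends up with different values even under the same seed.

```python
import torch
import torch.nn as nn

def init_uniform(model, seed=0):
    # Mirrors the test-initialization pattern quoted above.
    torch.manual_seed(seed)
    for param in model.parameters():
        param.data.uniform_(0, 1)
    return model

class SmallFirst(nn.Module):
    def __init__(self):
        super().__init__()
        self.small = nn.Linear(2, 2, bias=False)  # registered first
        self.big = nn.Linear(3, 3, bias=False)

class BigFirst(nn.Module):
    def __init__(self):
        super().__init__()
        self.big = nn.Linear(3, 3, bias=False)    # registered first
        self.small = nn.Linear(2, 2, bias=False)

a = init_uniform(SmallFirst())
b = init_uniform(BigFirst())
# Same submodule, same seed, but a.small is filled from the start of the
# RNG stream while b.small is filled only after b.big consumed 9 values.
```

This is why reordering module registration changes expected tensor values without indicating any functional bug.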

@ebsmothers (Contributor) left a comment

Thanks for making these changes! I left a few comments (mainly in response to some of @RdoubleA's points about RoPE -- TLDR is don't worry too much for this PR provided everything works). Otherwise no major concerns though

@Ankur-singh (Contributor Author)

@ebsmothers & @RdoubleA I have made the requested changes.

GPU test fails with the error message:

FAILED tests/recipes/dev/test_generate_v2.py::TestGenerateV2::test_llama2_generate_results - AssertionError: assert 'Country maior Connection Kohćutsójcustomulas Sometimes Security' in 'INFO     torchtune.utils._logging:_utils.py:28 Running InferenceRecipe with resolved config:\n\ncheckpointer:\n  _component_: torchtune.training.FullModelTorchTuneCheckpointer\n  checkpoint_dir: /tmp/test-artifacts\n  checkpoint_files:\n  - /tmp/test-artifacts/small-ckpt-tune-03082024.pt\n  model_type: LLAMA2\n  output_dir: /tmp/pytest-of-ec2-user/pytest-0/test_llama2_generate_results0\ndevice: cpu\ndtype: fp32\nlog_level: INFO\nmax_new_tokens: 10\nmodel:\n  _component_: torchtune.models.llama2.llama2\n  embed_dim: 256\n  max_seq_len: 2048\n  norm_eps: 1.0e-05\n  num_heads: 16\n  num_kv_heads: 8\n  num_layers: 4\n  vocab_size: 32000\noutput_dir: /tmp/pytest-of-ec2-user/pytest-0/test_llama2_generate_results0\nprompt:\n  system: You are a helpful and creative AI assistant.\n  user: What is the capital of France?\nseed: 123\ntemperature: 0.6\ntokenizer:\n  _component_: torchtune.models.llama2.llama2_tokenizer\n  max_seq_len: 2048\n  path: /tmp/test-artifacts/tokenizer.model\ntop_k: 300\n\nINFO     torchtune.utils._logging:generate_v2.py:94 Model was initialized with precision torch.float32.\nINFO     torchtune.utils._logging:generate_v2.py:208 \n\nPietroместkap щotimes rivers cache НиtringindexPathNAME\n\nINFO     torchtune.utils._logging:generate_v2.py:112 Time for inference: 0.08 sec total, 135.68 tokens/sec\nINFO     torchtune.utils._logging:generate_v2.py:115 Bandwidth achieved: 9.92 GiB/s\n'
====== 1 failed, 761 passed, 7 skipped, 14 warnings in 2548.47s (0:42:28) ======

I'm assuming it's because of initialization (refer to #2282 (comment)). I had to update the out tensors for the T5 encoder as well. I think we will have to take a closer look sometime in the future.

@RdoubleA RdoubleA mentioned this pull request Jan 21, 2025
@ebsmothers (Contributor)

Hi @Ankur-singh thanks for your patience, I wanted to sanity check the failing test to make sure it's expected. In our generate_v2 recipe we set the seed during init here. This means we construct the model after setting seed, and so when you change how we initialize the model (i.e. we now do some random initialization for every single layer, instead of just one layer that we then copy), the RNG state will be in a different place by the time we get to our call to generate, hence our sampling for generation will yield different results. You can confirm this by moving the call to training.set_seed e.g. here to force model initialization to happen first -- you will see that whether you run on main or on this PR you will get the same result.

Anyways TLDR for this PR is that you are not breaking anything, you can safely just update the expected value for this test. Separately we can think about whether it's clearer to call set_seed at the beginning of generate to ensure that we get deterministic behavior irrespective of model initialization. cc @joecummings in case he has any thoughts on this.
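A toy sketch of the point above (plain tensors, not the actual recipe code): when the seed is set before model construction, changing how initialization consumes the RNG shifts the state seen by sampling, while re-seeding immediately before sampling restores determinism.

```python
import torch

def init_cloned(n):
    # One random layer, cloned: consumes the RNG once.
    w = torch.randn(4)
    return [w.clone() for _ in range(n)]

def init_per_layer(n):
    # Each layer drawn independently: consumes the RNG n times.
    return [torch.randn(4) for _ in range(n)]

def sample_token():
    return torch.multinomial(torch.ones(10), 1).item()

torch.manual_seed(123)
init_cloned(4)
tok_cloned = sample_token()      # RNG advanced by 1 draw before sampling

torch.manual_seed(123)
init_per_layer(4)
tok_per_layer = sample_token()   # RNG advanced by 4 draws: state differs

torch.manual_seed(123)           # seeding right before sampling makes the
tok_a = sample_token()           # sampled token independent of how the
torch.manual_seed(123)           # model was initialized
tok_b = sample_token()
```

This is the trade-off mentioned above: calling set_seed at the start of generate would make sampling deterministic irrespective of model initialization.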

@Ankur-singh (Contributor Author) commented Jan 23, 2025

Hi @ebsmothers, thanks for clarifying. I'm a bit confused: after initializing the model, we load the state_dict from the checkpoint. This should basically overwrite any previous initialization, right? Furthermore, during testing, we are loading some dummy model weights. This should make the model weights deterministic, irrespective of the initialization method and random seed. Is my understanding correct so far?

random_seed only comes into the picture in the generate method when sampling the next token. So at least during testing, as long as we use the same random_seed and set_seed is called before generation starts, we should get the same output. The only thing that could lead to a different generation would be a change in the checkpoint weights.

I also tried out this small script to see if calling a random operation before setting the seed affects the outcome (looks like it does not):

import torch

# Case 1: No random operations before setting seed
torch.manual_seed(42)
x1 = torch.randn(3)
print("Case 1:", x1)

_ = torch.randn(5)  # Consumes RNG state
_ = torch.randn(5) 

# Case 2: Random operation before setting seed
torch.manual_seed(42)
x2 = torch.randn(3)
print("Case 2:", x2)
print(f"Case 1 and Case 2 are equal: {torch.allclose(x1, x2)}")

-----
# Output:
# Case 1: tensor([0.3367, 0.1288, 0.2345])
# Case 2: tensor([0.3367, 0.1288, 0.2345])
# Case 1 and Case 2 are equal: True

TL;DR: as long as we are loading the model weights from the same checkpoint and calling set_seed before the start of generation (either in __init__ or the generate method), we should get the same output. Am I missing something here?

PS: replacing the text here

"Country maior Connection Kohćutsójcustomulas Sometimes Security"
with "Pietroместkap щotimes rivers cache НиtringindexPathNAME" gets a green light. Tested it by running pytest tests/recipes/dev/test_generate_v2.py --ignore tests/torchtune/modules/_export --with-integration --durations=20 -vv
