
Replace RMSNorm by nn.RMSNorm #1464

Merged: 2 commits merged into main on Jan 24, 2025
Conversation

@manuelcandales (Contributor) commented Jan 16, 2025

In this PR we replace torchchat's own RMSNorm implementation with nn.RMSNorm, and we bump the PyTorch pin to capture the massive (30x-40x) speed-up to RMSNorm on the MPS backend introduced in pytorch/pytorch#145301.

Preliminary benchmarks on an M1 Pro with 16 GB RAM show a 33% speed-up in token generation when running Llama 3.2 1B with 4-bit quantization.

Motivation: Token generation on the MPS backend is currently CPU bound because of MPSGraph overhead. Surprisingly, the ops that impact performance the most are simple ones: mul, copy_, add, where, mean, rsqrt, sub, cat, stack. Experiments on an M1 Pro show that each of these op calls has at least 20 µs of CPU overhead on the MPS backend. These ops also dominate the graph: in aggregate, they are called 770 times for each token when running Llama 3.2 1B. Compare that to SDPA, which is called only 33 times, and linear, which is called 113 times. The per-op breakdown is listed below, followed by a sketch of how such counts can be gathered.

  • mul is called 275 times per token
  • copy_ is called 202 times per token
  • add is called 97 times per token
  • where is called 34 times per token
  • mean is called 33 times per token
  • rsqrt is called 33 times per token
  • sub is called 32 times per token
  • cat is called 32 times per token
  • stack is called 32 times per token
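For reference, a minimal sketch of how per-op call counts like these can be gathered with torch.profiler; the model and input here are illustrative stand-ins, not torchchat's actual model:

```
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Illustrative stand-ins: any module and input on the "mps" device work.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = nn.Sequential(nn.Linear(64, 64), nn.RMSNorm(64)).to(device)
x = torch.randn(1, 16, 64, device=device)

# Op dispatch happens on the CPU, so CPU activity captures the
# per-op call counts and the dispatch overhead discussed above.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    model(x)

# key_averages() aggregates calls per op (see the "# of Calls" column).
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
```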

Currently, torchchat's own RMSNorm operation is basically implemented like this:

```
norm = lambda t: t * torch.rsqrt(torch.mean(t * t, dim=-1, keepdim=True) + self.eps)
output = norm(x.float()).type_as(x) * self.weight
```

This means that a single call to torchchat's RMSNorm involves 3 calls to `aten::mul`, plus one call each to `aten::rsqrt`, `aten::mean`, and `aten::add`, i.e. 6 op calls in total. RMSNorm is called 33 times for each token. Hence, RMSNorm contributes 6 * 33 = 198 of those 770 op calls.
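For comparison, a minimal sketch of the replacement (the dimension and eps values below are illustrative, not torchchat's actual config): nn.RMSNorm exposes the same computation as a single module, which the faster MPS kernel from pytorch/pytorch#145301 can serve without dispatching each intermediate mul/mean/add/rsqrt separately.

```
import torch
import torch.nn as nn

dim, eps = 2048, 1e-5  # illustrative values, not torchchat's config

# One module call replaces the hand-rolled mul/mean/add/rsqrt/mul/mul chain.
norm = nn.RMSNorm(dim, eps=eps)

x = torch.randn(1, 16, dim)
print(norm(x).shape)  # torch.Size([1, 16, 2048])
```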

pytorch-bot (bot) commented Jan 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1464

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job

As of commit b8801f5 with merge base f4ae60f:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jan 16, 2025
@Jack-Khuu requested review from malfet and Jack-Khuu January 16, 2025 21:30
@Jack-Khuu (Contributor) left a comment

Thanks for the fix!!

Can you add a short snippet of your internal post to the description of this PR? It'll help with OSS visibility.

@Jack-Khuu changed the title Replace RMSNorm by nn.RMSNorm → [WIP - pending pt/pt changes] Replace RMSNorm by nn.RMSNorm Jan 21, 2025
@manuelcandales changed the title [WIP - pending pt/pt changes] Replace RMSNorm by nn.RMSNorm → Replace RMSNorm by nn.RMSNorm Jan 24, 2025
@manuelcandales merged commit 42c52bf into main Jan 24, 2025
62 checks passed
vmpuri pushed a commit that referenced this pull request Feb 4, 2025