
Support SegmentID when doing data parallel SPMD #8425

Merged
merged 3 commits into master on Nov 28, 2024

Conversation

@JackCaoG (Collaborator) commented Nov 27, 2024

This is built on top of #8333.

When a sharding spec is provided, we also need to shard the segment IDs. The data parallel case is the easiest one.

Q: [B, num_head, Q_S, head_dim]
K/V: [B, num_kv_head, KV_S, head_dim]

Q_segment_id: [B, Q_S]
K/V_segment_id: [B, KV_S]

In the data parallel case (or FSDP, which behaves the same way here since we all_gather all parameters, making them full), the mesh is 1D, e.g. (num_devices,) with name ("data",), and the sharding spec we pass to flash_attention is ("data", None, None, None). We just need to shard the segment IDs the same way, as in the sketch below.
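As a minimal sketch of that call site (shapes are illustrative; it assumes the flash_attention wrapper in torch_xla.experimental.custom_kernel, which this PR extends, takes segment IDs alongside partition_spec and mesh):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# 1D data-parallel mesh: only the batch dimension is sharded.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))

B, num_head, Q_S, KV_S, head_dim = 8, 4, 1024, 1024, 128
device = xm.xla_device()
q = torch.randn(B, num_head, Q_S, head_dim, device=device)
k = torch.randn(B, num_head, KV_S, head_dim, device=device)
v = torch.randn(B, num_head, KV_S, head_dim, device=device)
q_segment_ids = torch.ones(B, Q_S, device=device)
kv_segment_ids = torch.ones(B, KV_S, device=device)

# Batch-sharded attention; the segment IDs get sharded on dim 0 as well.
o = flash_attention(
    q, k, v,
    q_segment_ids=q_segment_ids,
    kv_segment_ids=kv_segment_ids,
    partition_spec=("data", None, None, None),
    mesh=mesh)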

The tricky part is what we save for the backward pass. I think we need to save the sharded segment_ids: you can imagine that after enable_manual_sharding, all of the computation is based on local shapes. segment_ids is not an output of flash_attention, hence we don't have to bring it back to the full shape. We save the full q/k/v, but we also use enable_manual_sharding to shard them again.
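Roughly, the forward path described above looks like the following sketch (illustrative, not the PR's exact code; enable_manual_sharding is the existing torch_xla SPMD API):

import torch_xla.distributed.spmd as xs

# partition_spec is e.g. ("data", None, None, None).
if partition_spec is not None:
  # From here on, computation runs on local (per-device) shapes.
  q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
  k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
  v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
  if q_segment_ids is not None:
    # Segment IDs are [B, S], so they only take the batch axis of the spec.
    segment_id_spec = (partition_spec[0], None)
    q_segment_ids = xs.enable_manual_sharding(
        q_segment_ids, segment_id_spec, mesh=mesh).global_tensor
    kv_segment_ids = xs.enable_manual_sharding(
        kv_segment_ids, segment_id_spec, mesh=mesh).global_tensor

# ... run the Pallas kernel on the local shards, then bring only the kernel
# outputs back to the full shape; the sharded segment IDs are what gets
# saved for the backward pass.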

Note that another tricky part is that q_segment_ids is not what we pass to the Pallas kernel; we actually add one dimension to it:

# [B, Q_S] -> [B, Q_S, NUM_LANES]
q_segment_ids = q_segment_ids.unsqueeze(-1).expand(
    [-1 for _ in q_segment_ids.shape] + [FlashAttention.NUM_LANES])
# [B, KV_S] -> [B, NUM_SUBLANES, KV_S]
kv_segment_ids = kv_segment_ids.unsqueeze(1).expand([
    kv_segment_ids.shape[0], FlashAttention.NUM_SUBLANES,
    kv_segment_ids.shape[1]
])
In this PR I also rename the 3D tensor to q_segment_ids_fa to make it clearer.
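A quick shape check of that expansion, runnable with plain torch on CPU (the constants here assume FlashAttention.NUM_LANES = 128 and NUM_SUBLANES = 8):

import torch

B, Q_S, KV_S = 2, 1024, 2048
NUM_LANES, NUM_SUBLANES = 128, 8  # assumed FlashAttention constants

q_segment_ids = torch.ones(B, Q_S)
kv_segment_ids = torch.ones(B, KV_S)

# [B, Q_S] -> [B, Q_S, NUM_LANES]
q_segment_ids_fa = q_segment_ids.unsqueeze(-1).expand(
    [-1 for _ in q_segment_ids.shape] + [NUM_LANES])
# [B, KV_S] -> [B, NUM_SUBLANES, KV_S]
kv_segment_ids_fa = kv_segment_ids.unsqueeze(1).expand(
    [kv_segment_ids.shape[0], NUM_SUBLANES, kv_segment_ids.shape[1]])

print(q_segment_ids_fa.shape)   # torch.Size([2, 1024, 128])
print(kv_segment_ids_fa.shape)  # torch.Size([2, 8, 2048])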

@JackCaoG added the tpuci label on Nov 27, 2024
@JackCaoG force-pushed the JackCaoG/enable_segmentid_spmd_data_parallel branch from 165da7b to c4db598 on Nov 28, 2024
@JackCaoG marked this pull request as ready for review on Nov 28, 2024
@JackCaoG merged commit 1c91219 into master on Nov 28, 2024
12 checks passed
rpsilva-aws pushed a commit to rpsilva-aws/xla that referenced this pull request Dec 6, 2024