
Added Distributed(Tensor Parallel) Inference Recipe #2245

Merged: 34 commits merged into pytorch:main on Jan 18, 2025

Conversation

@acisseJZhong (Contributor) commented Jan 10, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
Enabled TP for inference for llama3, 3.1, and 3.3 70B text-only models. Llama3.2 Vision 90B is still a work in progress; it is blocked by #2277 and will be enabled in a follow-up PR.

  • Copied dev/generate_v2.py and added TP to the recipe. The main changes are in __init__ and __setup__.
  • Added a TP plan for llama3, under the model folder's _parallelism.py file.
  • Added TP utilities in _distributed.py; note that the utilities are for now only designed to work with llama models.
  • Added distributed inference configs for llama3 70B, 3.1 70B, and 3.3 70B.
  • Generalized load_from_full_model_state_dict to general parallelism, not only FSDP.
  • Fixed a few typos.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
  • Compared distributed inference results for llama3 8B with non-distributed inference. Generations are at parity; below are the screenshot results for distributed and non-distributed inference. Distributed inference takes longer but has lower peak memory.
[screenshots: distributed vs. non-distributed inference generations]
  • Running llama3 70B distributed inference on 8 GPUs:
Time for inference: 9.03 sec total, 3.10 tokens/sec
Bandwidth achieved: 408.57 GiB/s
Max memory allocated: 18.70 GiB

We see that the max memory allocated is only 18.7 GiB, which indicates we could use fewer GPUs.
Running on 2 GPUs, we see that the max memory increases:

Time for inference: 10.42 sec total, 3.17 tokens/sec
Bandwidth achieved: 417.21 GiB/s
Max memory allocated: 67.98 GiB

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot commented Jan 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2245

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit 1ad2f76 with merge base 7747db1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 10, 2025
@acisseJZhong acisseJZhong changed the title [TP] Added Distributed Inference Recipe Added Distributed(Tensor Parallel) Inference Recipe Jan 10, 2025
@codecov-commenter commented Jan 10, 2025

Codecov Report

Attention: Patch coverage is 28.94737% with 27 lines in your changes missing coverage. Please review.

Project coverage is 23.95%. Comparing base (baae232) to head (41941a9).
Report is 15 commits behind head on main.

Files with missing lines | Patch % | Missing lines
torchtune/training/_distributed.py | 25.00% | 24 ⚠️
torchtune/modules/model_fusion/_fusion_layers.py | 0.00% | 2 ⚠️
torchtune/modules/attention.py | 0.00% | 1 ⚠️

❗ There is a different number of reports uploaded between BASE (baae232) and HEAD (41941a9): HEAD has 2 fewer uploads than BASE (BASE: 3, HEAD: 1).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2245       +/-   ##
===========================================
- Coverage   64.30%   23.95%   -40.35%     
===========================================
  Files         352      357        +5     
  Lines       20566    21174      +608     
===========================================
- Hits        13225     5073     -8152     
- Misses       7341    16101     +8760     


@acisseJZhong acisseJZhong requested review from RdoubleA, ebsmothers and joecummings and removed request for RdoubleA and ebsmothers January 10, 2025 07:21
@felipemello1 (Contributor) left a comment

Some minor comments/questions

@@ -45,6 +48,18 @@
"dev" not in torch_version and torch_version_ge("2.6.0")
) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")

BASE_LLAMA_TP_PLAN = {
Contributor

Do we need one for each family of models? If so, is this file the right place to store it?

@acisseJZhong (Contributor Author) commented Jan 16, 2025

Yeah I am also curious what's the best place to store this info. The plan should be shared within llama3, 3.1, and 3.2, but we should define unique plans for 3.2 vision and 4. Maybe it should be stored in _model_builders.py? What's a better place?

Contributor

I am not sure. I feel that _model_builders.py would be too scattered. If we had training/distributed/, i would put it there. Does TorchTitan have something like this for multiple models? maybe we could check how they do it.

Every model has a checkpoint mapping torchtune <-> hf. How do we handle it? Probably we should follow the same pattern.

Contributor Author

Torchtitan didn't define plan for each model, they just have one apply_tp function for llama3 in a parallelize_llama.py.

I saw each model has a convert_weights.py file for converting weights format between hf and torchtune. Maybe let me create a parallelism file for each model, to put all the plans.

Contributor

I don't think we need a parallelism file for every model. The vast majority of models that we support will be able to fall under the BASE_LLAMA_TP_PLAN. It doesn't have to live in training/ but it should live in somewhere centralized. Then, if there is a specific TP plan that we want to enable for, say, LLama3.2V, then we can define it either in the _model_builders.py file OR we can add a _parallelism.py file under the model directory where we define the TP plan.

Contributor

Call it TRANSFORMER_DECODER_TP_PLAN or similar so it's not llama-specific. maybe we'll finally need a distributed folder 👀

Contributor

I'd vote for something like BASIC_TP_PLAN

Contributor

BASED_TP_PLAN

(this is a joke)

Contributor

Hahahaha ... unless? 👀

Comment on lines 52 to 61
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(output_layouts=Replicate()),
"layers.*.attn.q_proj": ColwiseParallel(),
"layers.*.attn.k_proj": ColwiseParallel(),
"layers.*.attn.v_proj": ColwiseParallel(),
"layers.*.attn.output_proj": RowwiseParallel(),
"layers.*.mlp.w1": ColwiseParallel(),
"layers.*.mlp.w2": RowwiseParallel(),
"layers.*.mlp.w3": ColwiseParallel(),
}
Contributor

n00b question: is this row/col the optimal setup? or is it somewhat arbitrary?

@acisseJZhong (Contributor Author) commented Jan 16, 2025

For matrix multiplication, we just need to make sure one matrix is Col and the other is Row. For example, because the math is mlp.w2(mlp.w1(x) * mlp.w3(x)), we just need to make sure that w1 and w3 are col and w2 is row, or the other way around.
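A quick numerical check of that pairing (not from the PR, just an illustration): column-sharding w1/w3 and row-sharding w2 reproduces the unsharded MLP output once the per-shard partial results are summed, which is the all-reduce that RowwiseParallel inserts.

import torch

dim, hidden, tp = 8, 16, 2
x = torch.randn(4, dim)
w1, w3 = torch.randn(dim, hidden), torch.randn(dim, hidden)
w2 = torch.randn(hidden, dim)

# Unsharded reference: w2(w1(x) * w3(x)), with each weight applied as x @ W
full = ((x @ w1) * (x @ w3)) @ w2

partials = []
for rank in range(tp):
    cols = slice(rank * hidden // tp, (rank + 1) * hidden // tp)
    # Each rank holds a column shard of w1/w3 and the matching row shard of w2
    partials.append(((x @ w1[:, cols]) * (x @ w3[:, cols])) @ w2[cols, :])

assert torch.allclose(full, sum(partials), atol=1e-4)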

Contributor

optional: Maybe adding this comment on top of it would be good

Comment on lines 566 to 577
def get_tp_plan(model_type: str) -> Dict[str, ParallelStyle]:
"""
Get the TP plan for a given model type.

Args:
model_type (str): The model type to get the TP plan for.

Returns:
Dict[str, str]: A dictionary mapping layer names to their corresponding TP plan.
"""
# For now, we only support base TP plan, will add more plan later
return BASE_LLAMA_TP_PLAN
Contributor

I understand that this is a v0, but should we add something like:

if model_type not in LLAMA_MODEL_TYPES:
    raise ValueError("TP is only supported for llama type models")

Returns:
nn.Module: Adjusted model.
"""
for transformer_block in model.layers:
Contributor

this will break for vision model, since we do model.decoder.layers, unless we call adjust_attention_for_tp(model=model.decoder)

@acisseJZhong (Contributor Author) commented Jan 16, 2025

thanks! yeah I had this in my local changes, trying to make vision 3.2 work. I did the following:

model = getattr(model, 'decoder', model)

Let me know if you have better ideas.

Contributor

not my proudest moment:

if isinstance(model, DeepFusionModel):
    model = model.decoder

"""
for transformer_block in model.layers:
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attn
Contributor

nit: this is ok, but maybe a more robust option would be to look for the module type == SelfAttentionLayer

Comment on lines +28 to +36
Expects the YAML to look like:
system: You are a helpful AI assistant.
user: What is the capital of France?

or if it includes an image:
system: You are a helpful AI assistant.
user:
image: url or path_to_image
text: Describe the image in detail.
Contributor

We should denote that it is a ::codeblock: yaml, ask some LLM for formatting

Contributor

Even the strongest LLMs cannot comprehend sphinx rst syntax

Contributor

Fwiw this isn't gonna show up in our live docs anyways, right? In that case I would lean away from Sphinx directives -- if people are just reading the code it'll needlessly clutter things up

Contributor

(But separately we should think about putting this somewhere besides the recipe file anyways, especially now that we're copying the same class to two different recipes)

self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group("cuda:nccl")
Contributor

i have seen in other parts of the code this resulting in errors if we dont do init_process_group("cuda:nccl,cpu:gloo")

@acisseJZhong (Contributor Author) commented Jan 16, 2025

I did see that other files have cpu:gloo. Curious why, since we will not use CPU as a backend for inference?

Contributor

i have a vague memory that there may be some weight that is initialized using cpu, for some reason, and without cpu:gloo, it raises an error. But I dont remember exactly the issue. In any case, I dont think it hurts to add it.

Contributor Author

I tried adding cpu:gloo when initializing the process group. But I am getting some RMSNorm cuda cpu device mismatch. With just cuda:nccl and the exact same code, it works with no problem. https://www.internalfb.com/phabricator/paste/view/P1713611664

Contributor

Yeah I think gloo PG is only relevant for CPU offloading, which we aren't doing today anyways

# Set up tenosr parallel device mesh
tp_degree = dist.get_world_size() # Using all GPUs for TP
tp_mesh_shape = (tp_degree,)
tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)
Contributor

n00b question: should we worry about other device types, e.g. npu?

Contributor

I think it's fine to leave that out in a first pass. Last I knew of we don't yet have distributed support for NPUs anyways (though @noemotiovon can inform me if my info is out of date here)
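If device generality ever becomes a requirement, one low-cost option is to derive the device type from the recipe's device instead of hardcoding "cuda". This is just a sketch, not part of this PR, and init_tp_mesh is a hypothetical helper name:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

def init_tp_mesh(device: torch.device) -> DeviceMesh:
    # Use the recipe's device type ("cuda", "npu", ...) rather than a literal "cuda",
    # with all ranks placed in a single tensor-parallel dimension.
    return init_device_mesh(device.type, (dist.get_world_size(),))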


# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
@felipemello1 (Contributor) commented Jan 16, 2025

nit: as a rule of thumb, I think it's worth using keyword arguments for all args, not only strict and cpu_offload

Comment on lines 151 to 154
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
Contributor

nit: I think in general we prefer to use GiB, otherwise it may appear that we used more memory than the GPU has available. To change this, replace 1e9 with /1024/1024.

@joecummings (Contributor) left a comment

What kind of tok/sec are we seeing with TP and Llama3 for the distributed inference recipe?

@@ -546,3 +564,72 @@ def shard_model(

# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)


def get_tp_plan(model_type: str) -> Dict[str, ParallelStyle]:
Contributor

I actually want to avoid something like this. It's very similar to how we did checkpointing where we would have complicated if/else logic that quickly got very confusing.

I'm imagining that the user would be able to pass in their TP plan directly from the config if they want to use tensor parallel.

Contributor

Agreed. I think just pointing directly to a function in the config with the plan they want gives the users the most flexibility

tensor_parallel_plan:
  _component_: torchtune.training.BASE_LLAMA_TP_PLAN

Contributor Author

I am concerned that things could become messier if we allow users to directly change the TP plan from the config, especially once we have more parallelism enabled. In most cases the default TP plan should suffice; it's unclear why users would need to change it. Ideally each model will have a default TP plan under _parallelism.py, and if advanced users really want to experiment with the plan, they can modify it there.

Contributor

if advanced users really want to experiment with the plan, they can modify it there.

This would necessitate users using git clone to interact with torchtune, which isn't the case for many of our users. It has to somehow be possible to override the TP plan via some builder or directly in the config. We could just call the value parallelism_plan, which could extend to other types as well.

Contributor

I'd say those are two separate problems - where to place the plans and how to specify them. I agree with placing the plans under each model's _parallelism.py or some central place for the default, but for accessing them you'll either need to do some weird dictionary mapping or specify it directly from the config. Since all the generation configs are model specific, we can just point to the default for each model.

The direct from config approach just lets you bypass the whole dict mapping, which is cleaner imo. Otherwise you will need to keep updating the mapping for every new model.
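For reference, a rough sketch of what the config-pointer approach could look like inside the recipe. This is hedged: apply_tp_from_config is a hypothetical helper, and the exact prepare_mha_for_tp signature and the torchtune.models.llama3.base_llama_tp_plan path are assumptions pieced together from other comments in this PR.

from torch.distributed.tensor.parallel import parallelize_module
from torchtune import config, training

def apply_tp_from_config(cfg, model, tp_device_mesh):
    # In the YAML, e.g.:
    #   tensor_parallel_plan:
    #     _component_: torchtune.models.llama3.base_llama_tp_plan
    tp_plan = config.instantiate(cfg.tensor_parallel_plan)
    # Make num_heads / num_kv_heads / embed_dim TP-aware before sharding
    model = training.prepare_mha_for_tp(model, tp_device_mesh)
    return parallelize_module(model, tp_device_mesh, parallelize_plan=tp_plan)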


tp_mesh: DeviceMesh,
) -> nn.Module:
"""
Adjusts the number of attention heads and dimension in the model to account for tensor parallelism.
Contributor

This description is too vague. We should communicate what exactly is happening to the users.

"""
# Consider the case of Early Fusion or Deep Fusion models
if isinstance(model, DeepFusionModel):
model = model.docoder
Contributor

Suggested change
model = model.docoder
model = model.decoder

Contributor

i think that docoder sounds nicer

assert attn.num_heads % tp_mesh.size() == 0
assert attn.num_kv_heads % tp_mesh.size() == 0
assert attn.embed_dim % tp_mesh.size() == 0
attn.num_heads = attn.num_heads // tp_mesh.size()
Contributor

nit: Don't need to use the floor division operator if you already determined that the tp_mesh.size() goes evenly into the num_heads, etc.

Contributor Author

it's expected to be an int, so I used //

Contributor

Right, but both num_heads and tp_mesh.size() are ints so it'll always be an int.

@acisseJZhong (Contributor Author) commented Jan 16, 2025

it's not in my case :( I printed it out in code; even though both numerator and denominator are ints, the result is a float.

Contributor

weird ... okay sounds good then!
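(For what it's worth, in Python 3 true division always returns a float even when both operands are ints, so the floor division really is needed to keep these attributes as ints:)

>>> heads, tp_size = 32, 4
>>> heads / tp_size
8.0
>>> heads // tp_size
8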

# Adjust attention module to use the local number of heads
attention_layers = ([layer.attn] if not isinstance(layer, FusionLayer) else [layer.fusion_layer.attn, layer.layer.attn])
for attn in attention_layers:
assert attn.num_heads % tp_mesh.size() == 0
Contributor

May as well pull the tp_mesh.size() call out to the top.

raise ValueError("TP is only supported for llama type models right now.")


def adjust_attention_for_tp(
Contributor

Should we be more precise here? Something like scale_attention_heads_by_tp_size?

The only reason I could think not to do this is if there might be other adjustments to the attention we might need to do.

Contributor Author

It changes the number of heads as well as the embed dim. Maybe shard_attention_params_for_tp?

Contributor

Just reviewed and I personally still found shard_attention_params_for_tp a bit confusing. Really you are just setting num_heads so that the reshapes are TP-aware, right? If that's the case I would even lean towards calling it prepare_mha_for_tp or something (the current name makes me think you are actually distributing the params across devices in this utility, which you are not). Have I mentioned I hate naming things

tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)

# Get TP plan and apply TP
tp_plan = training.get_tp_plan(cfg.checkpointer.model_type)
Contributor

I would not count on model_type staying around.

Contributor

Agreed, we should find some other way to reliably get the TP plan depending on the model... maybe this could be a parameter in the config that points to a function?

@RdoubleA (Contributor) left a comment

Really like this, we just need to figure out how we expect users to use our built-in TP plans, whether they can supply their own, and how this can be specified from the config.


This *does not* currently support the following features:
- torch.compile
- quantization through torchao
- multi-GPU generation
Contributor

should update this and call out that this can be run distributed for larger models using TP. And could point to some pytorch docs on TP

self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group("cuda:nccl,cpu:gloo")
Contributor

I think we have a utility for this that also sets the port, etc?

Contributor Author

I was following full_finetune_distributed.py and it seems like it doesn't use any util, just init_process_group().

with training.set_default_dtype(self._dtype), torch.device("meta"):
model = config.instantiate(cfg.model)

# Set up tenosr parallel device mesh
Contributor

Suggested change
# Set up tenosr parallel device mesh
# Set up tensor parallel device mesh

model = config.instantiate(cfg.model)

# Set up tenosr parallel device mesh
tp_degree = dist.get_world_size() # Using all GPUs for TP
Contributor

will this also work on multinode?

Contributor

Please don't make me worry about multi-node INFERENCE too

Contributor

then we should prevent users from trying multinode somewhere?

Contributor

Can drop a comment at the top of the recipe

Contributor Author

Added into the docstring for the InferenceRecipe class.

@@ -433,6 +433,17 @@ class Recipe:
],
supports_distributed=False,
),
Recipe(
Contributor

you need to add the 3.1 and 3.2 configs here, no?

Contributor Author

Added configs for 3, 3.1, and 3.2; for 3.2 I am still debugging.



# Define the parallelism plan for Llama3.2 vision model
LLAMA_DEEP_FUSION_VISION_TP_PLAN = {
Contributor

I would just call this LLAMA_3_2_VISION_TP_PLAN



for attn in attention_layers:
assert attn.num_heads % tp_mesh.size() == 0
assert attn.num_kv_heads % tp_mesh.size() == 0
assert attn.embed_dim % tp_mesh.size() == 0
Contributor

use if... raise instead. It is more descriptive and you can describe what needs to be changed to fix the error

@acisseJZhong (Contributor Author) commented Jan 16, 2025

What kind of tok/sec are we seeing with TP and Llama3 for the distributed inference recipe?

It takes ~10 seconds total, at ~3 tokens/sec, to run inference on a single prompt for llama3 and 3.1.

@ebsmothers (Contributor) left a comment

This is looking great! Really excited to see this landing in the library. Apart from inline comments, main questions are around testing -- it'd be great to update the PR summary to address the following questions.

  1. Are the generations at parity with the corresponding ones from generate_v2.py (obviously not feasible for 70B, but could test on e.g. an 8B model)?
  2. What's the throughput and peak memory? (On a related note, are 8 devices necessary to run 70B inference? If not, can maybe use a smaller number in the config) Edit: I see you answered this in a comment above, still might be good to include comprehensive results in the PR test plan
  3. Do the TP sharding utilities work with our other model families (Gemma, Mistral, Phi, etc)? Any models that they definitely do not work with? If that's the case we don't have to block on supporting everything, just want to be explicit about that.

@@ -109,10 +113,10 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024:.02f} GB/s"
Contributor

Maybe I'm missing something obvious, but shouldn't there be a 3rd 1024 here (and below)?

# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed.yaml
Contributor

Should remove .yaml from these commands once the configs are added to the recipe registry

Comment on lines +15 to +16
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(output_layouts=Replicate()),
Contributor

Curious about this comment in the torchchat code

@acisseJZhong (Contributor Author) commented Jan 17, 2025

yeah I am curious how they came to this conclusion. For torchtune inference, I tried commenting these two lines out and the inference speed doesn't change much (I ran it several times; sometimes it's slower and sometimes faster, always differing by 0.0x seconds).

Contributor

Thanks, that's good to know. cc @fduwjj as the author of the corresponding torchchat PR in case you have any insights

}


def base_llama_tp_plan() -> Dict[str, Any]:
Contributor

Can values in the return dict be typed as ParallelStyle? Or are there cases where the plan for a layer might fall outside these classes

"decoder.layers.*.layer.attn.k_proj": ColwiseParallel(),
"decoder.layers.*.layer.attn.v_proj": ColwiseParallel(),
"decoder.layers.*.layer.attn.output_proj": RowwiseParallel(),
"decoder.layers.*.layer.mlp.w1": ColwiseParallel(),
Contributor

Can we consolidate e.g. decoder.layers.*.layer.mlp.w1 and decoder.layers.*.fusion_layer.mlp.w1 -> decoder.layers.*.mlp.w1 or something like that? (Similar question for other layers)

Contributor Author

Good catch, I think it's the same; the wildcard should support this. But let me test it on 3.2 later, right now I'm hitting a distributed state dict problem with 3.2.


self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
Contributor

Do we actually wanna support this on CPU?

Contributor Author

probably not, let me remove that.


Comment on lines 157 to 160
f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024:.02f} GB/s"
)
self._logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024:.02f} GB"
Contributor

Same comment here as in generate_v2.py

Comment on lines 617 to 619
attn.num_heads = attn.num_heads // tp_size
attn.num_kv_heads = attn.num_kv_heads // tp_size
attn.embed_dim = attn.embed_dim // tp_size
Contributor

Just want to make sure I understand the purpose of this utility: the idea is to set the appropriate params on the attention module so that any reshapes etc performed there will result in the correctly-shaped input for sharded Q, K, V, output projections?

Contributor

Also wonder whether we could just iterate over e.g. model.modules() and check isinstance(m, MultiHeadAttentionLayer) to avoid any dependency on the actual higher-level model arch details (fusion vs not, etc). For training this might be risky (e.g. AC will wrap modules and then you'd need to decide whether to apply on the AC-wrapped module vs not), but I don't see an obvious case it would break in inference.
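A rough sketch of that modules()-based alternative, folding in the earlier if/raise suggestion (assumes torchtune's MultiHeadAttention class; not necessarily what the PR ends up doing):

from torch import nn
from torchtune.modules import MultiHeadAttention

def prepare_mha_for_tp(model: nn.Module, tp_size: int) -> nn.Module:
    # Walk every submodule so fusion vs. non-fusion architectures need no special casing
    for m in model.modules():
        if isinstance(m, MultiHeadAttention):
            if (m.num_heads % tp_size) or (m.num_kv_heads % tp_size) or (m.embed_dim % tp_size):
                raise ValueError(
                    "num_heads, num_kv_heads, and embed_dim must all be divisible by the TP size"
                )
            m.num_heads //= tp_size
            m.num_kv_heads //= tp_size
            m.embed_dim //= tp_size
    return model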

@@ -546,3 +550,71 @@ def shard_model(

# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)


def shard_attention_params_for_tp(
Contributor

Awesome!

model:
_component_: torchtune.models.llama3.llama3_70b

tensor_parallel_plan:
Contributor

I think we might want to mirror the parallelize_module API and call this parallelize_plan. See https://pytorch.org/docs/main/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.parallelize_module

@joecummings (Contributor) left a comment

Incredible work!

@@ -546,3 +548,66 @@ def shard_model(

# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)


def prepare_mha_for_tp(
Contributor

This should probably have a test :)

Contributor Author

added! thanks for the reminder.

@@ -109,10 +113,10 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024 / 1024:.02f} GiB/s"
Contributor

nit nit nit: 1024 ** 3
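A small self-contained version of what that logging could look like (a sketch only; log_memory_and_bandwidth is a hypothetical helper, and the key point is the 1024 ** 3 divisor for GiB):

import torch

GIB = 1024 ** 3  # bytes per GiB

def log_memory_and_bandwidth(logger, model_size_bytes: int, tokens_per_second: float) -> None:
    # Bytes of weights read per second, reported in GiB/s
    logger.info(f"Bandwidth achieved: {model_size_bytes * tokens_per_second / GIB:.02f} GiB/s")
    # Peak CUDA memory over the run, reported in GiB
    logger.info(f"Max memory allocated: {torch.cuda.max_memory_allocated() / GIB:.02f} GiB")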

@acisseJZhong acisseJZhong force-pushed the distributed_inference branch from 1b4f781 to a80b7e5 Compare January 17, 2025 22:29
@acisseJZhong acisseJZhong merged commit 779569e into pytorch:main Jan 18, 2025
17 checks passed
@RdoubleA RdoubleA mentioned this pull request Jan 21, 2025