Added Distributed (Tensor Parallel) Inference Recipe #2245
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2245
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 3 Pending as of commit 1ad2f76 with merge base 7747db1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report. Attention: Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
##             main    #2245       +/-   ##
===========================================
- Coverage   64.30%   23.95%   -40.35%
===========================================
  Files         352      357        +5
  Lines       20566    21174      +608
===========================================
- Hits        13225     5073     -8152
- Misses       7341    16101     +8760
☔ View full report in Codecov by Sentry.
Some minor comments/questions.
torchtune/training/_distributed.py
Outdated
@@ -45,6 +48,18 @@
    "dev" not in torch_version and torch_version_ge("2.6.0")
) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")

BASE_LLAMA_TP_PLAN = {
Do we need one for each family of models? If so, is this file the right place to store it?
Yeah I am also curious what's the best place to store this info. The plan should be shared within llama3, 3.1, and 3.2, but we should define unique plans for 3.2 vision and 4. Maybe it should be stored in _model_builders.py? What's a better place?
I am not sure. I feel that _model_builders.py would be too scattered. If we had training/distributed/, I would put it there. Does TorchTitan have something like this for multiple models? Maybe we could check how they do it.
Every model has a checkpoint mapping torchtune <-> hf. How do we handle it? Probably we should follow the same pattern.
TorchTitan didn't define a plan for each model; they just have one apply_tp function for llama3 in parallelize_llama.py.
I saw each model has a convert_weights.py file for converting the weights format between hf and torchtune. Maybe let me create a parallelism file for each model to put all the plans in.
I don't think we need a parallelism file for every model. The vast majority of models that we support will be able to fall under the BASE_LLAMA_TP_PLAN. It doesn't have to live in training/ but it should live somewhere centralized. Then, if there is a specific TP plan that we want to enable for, say, Llama3.2V, then we can define it either in the _model_builders.py file OR we can add a _parallelism.py file under the model directory where we define the TP plan.
Call it TRANSFORMER_DECODER_TP_PLAN or similar so it's not llama-specific. maybe we'll finally need a distributed folder 👀
I'd vote for something like BASIC_TP_PLAN
BASED_TP_PLAN
(this is a joke)
Hahahaha ... unless? 👀
torchtune/training/_distributed.py
Outdated
    "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
    "output": ColwiseParallel(output_layouts=Replicate()),
    "layers.*.attn.q_proj": ColwiseParallel(),
    "layers.*.attn.k_proj": ColwiseParallel(),
    "layers.*.attn.v_proj": ColwiseParallel(),
    "layers.*.attn.output_proj": RowwiseParallel(),
    "layers.*.mlp.w1": ColwiseParallel(),
    "layers.*.mlp.w2": RowwiseParallel(),
    "layers.*.mlp.w3": ColwiseParallel(),
}
n00b question: is this row/col the optimal setup? or is it somewhat arbitrary?
For matrix multiplication, we just need to make sure one matrix is Col and the other is Row. For example, because the math is mlp.w2(mlp.w1(x) * mlp.w3(x)), we just need to make sure that w1 and w3 are col and w2 is row, or the other way around.
optional: Maybe adding this comment on top of it would be good
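For illustration, a minimal sketch of what such an inline annotation on top of the plan might look like (the comment wording is hypothetical; imports are shown for completeness):

```python
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# For every matmul pair, one projection is ColwiseParallel and its counterpart
# RowwiseParallel so the sharded partial results recombine correctly.
# The MLP computes w2(w1(x) * w3(x)), so w1/w3 are colwise and w2 is rowwise
# (or the other way around); likewise q/k/v are colwise and output_proj is rowwise.
BASE_LLAMA_TP_PLAN = {
    "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
    "output": ColwiseParallel(output_layouts=Replicate()),
    "layers.*.attn.q_proj": ColwiseParallel(),
    "layers.*.attn.k_proj": ColwiseParallel(),
    "layers.*.attn.v_proj": ColwiseParallel(),
    "layers.*.attn.output_proj": RowwiseParallel(),
    "layers.*.mlp.w1": ColwiseParallel(),
    "layers.*.mlp.w2": RowwiseParallel(),
    "layers.*.mlp.w3": ColwiseParallel(),
}
```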
torchtune/training/_distributed.py
Outdated
def get_tp_plan(model_type: str) -> Dict[str, ParallelStyle]:
    """
    Get the TP plan for a given model type.

    Args:
        model_type (str): The model type to get the TP plan for.

    Returns:
        Dict[str, str]: A dictionary mapping layer names to their corresponding TP plan.
    """
    # For now, we only support base TP plan, will add more plan later
    return BASE_LLAMA_TP_PLAN
I understand that this is a v0, but should we add something like:
if model_type not in LLAMA_MODEL_TYPES:
    raise ValueError("TP is only supported for llama type models")
torchtune/training/_distributed.py
Outdated
    Returns:
        nn.Module: Adjusted model.
    """
    for transformer_block in model.layers:
this will break for vision model, since we do model.decoder.layers, unless we call adjust_attention_for_tp(model=model.decoder)
thanks! yeah I had this in my local changes, trying to make vision 3.2 work. I did the following:
model = getattr(model, 'decoder', model)
Let me know if you have better ideas.
not my proudest moment:
torchtune/torchtune/training/_compile.py
Line 46 in 890deab
if isinstance(model, DeepFusionModel):
torchtune/training/_distributed.py
Outdated
    """
    for transformer_block in model.layers:
        # Adjust attention module to use the local number of heads
        attn_layer = transformer_block.attn
nit: this is ok, but maybe a more robust option would be to look for the module type == SelfAttentionLayer
    Expects the YAML to look like:
        system: You are a helpful AI assistant.
        user: What is the capital of France?

    or if it includes an image:
        system: You are a helpful AI assistant.
        user:
            image: url or path_to_image
            text: Describe the image in detail.
We should denote that it is a ::codeblock: yaml, ask some LLM for formatting.
Even the strongest LLMs cannot comprehend sphinx rst syntax
Fwiw this isn't gonna show up in our live docs anyways, right? In that case I would lean away from Sphinx directives -- if people are just reading the code it'll needlessly clutter things up
(But separately we should think about putting this somewhere besides the recipe file anyways, especially now that we're copying the same class to two different recipes)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group("cuda:nccl")
I have seen this result in errors in other parts of the code if we don't do init_process_group("cuda:nccl,cpu:gloo").
I did see another file has cpu:gloo, curious why? Since we will not use cpu as the backend for inference.
I have a vague memory that there may be some weight that is initialized using cpu, for some reason, and without cpu:gloo it raises an error. But I don't remember exactly the issue. In any case, I don't think it hurts to add it.
I tried adding cpu:gloo when initializing the process group, but I am getting an RMSNorm cuda/cpu device mismatch. With just cuda:nccl and the exact same code, it works with no problem. https://www.internalfb.com/phabricator/paste/view/P1713611664
Yeah I think gloo PG is only relevant for CPU offloading, which we aren't doing today anyways
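For reference, a minimal sketch of the dual-backend initialization being discussed; whether the gloo group is actually needed here is exactly the open question above:

```python
import torch.distributed as dist

# NCCL handles the CUDA collectives; the gloo group only matters if some
# tensors live on CPU (e.g. CPU offloading), which this inference recipe
# does not do today.
dist.init_process_group("cuda:nccl,cpu:gloo")
```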
# Set up tenosr parallel device mesh
tp_degree = dist.get_world_size()  # Using all GPUs for TP
tp_mesh_shape = (tp_degree,)
tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)
n00b question: should we worry about other device types, e.g. npu?
I think it's fine to leave that out in a first pass. Last I knew of we don't yet have distributed support for NPUs anyways (though @noemotiovon can inform me if my info is out of date here)
# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
nit: as a rule of thumb, I think it's worth using keyword arguments for all args, not only strict and cpu_offload.
    f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
    f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
nit: I think in general we prefer to use GiB, otherwise it may appear that we used more memory than the GPU has available --> to change, replace 1e9 with /1024/1024.
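A sketch of the bytes-to-GiB conversion under discussion (note that it takes three factors of 1024, which comes up again further down):

```python
import torch

max_mem_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)  # bytes -> GiB
print(f"Max memory allocated: {max_mem_gib:.02f} GiB")
```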
What kind of tok/sec are we seeing with TP and Llama3 for the distributed inference recipe?
torchtune/training/_distributed.py
Outdated
@@ -546,3 +564,72 @@ def shard_model(

    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)


def get_tp_plan(model_type: str) -> Dict[str, ParallelStyle]:
I actually want to avoid something like this. It's very similar to how we did checkpointing where we would have complicated if/else logic that quickly got very confusing.
I'm imagining that the user would be able to pass in their TP plan directly from the config if they want to use tensor parallel.
Agreed. I think just pointing directly to a function in the config with the plan they want gives the users the most flexibility:
tensor_parallel_plan:
  _component_: torchtune.training.BASE_LLAMA_TP_PLAN
I am concerned that things could become messier if we allow users to directly change the TP plan from the config, especially if we have more parallelism enabled. In most cases the default TP plan should suffice; it's unclear why users would need to change it. Ideally each model will have a default TP plan under _parallelism.py, and if advanced users really want to experiment with the plan, they can modify it there.
> if advanced user really want to experiment with plan, they can modify there.
This would necessitate users using git clone to interact with torchtune, which isn't the case for many of our users. It has to somehow be possible to override the TP plan via some builder or directly in the config. We could just call the value parallelism_plan, which could extend to other types as well.
I'd say those are two separate problems - where to place the plans and how to specify them. I agree with placing the plans under each model's _parallelism.py or some central place for the default, but for accessing them you'll either need to do some weird dictionary mapping or specify it directly from the config. Since all the generation configs are model specific, we can just point to the default for each model.
The direct-from-config approach just lets you bypass the whole dict mapping, which is cleaner imo. Otherwise you will need to keep updating the mapping for every new model.
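A rough sketch of how the config-driven approach could look in the recipe; the config key tensor_parallel_plan is taken from this thread, while the wrapper function and the exact component path are illustrative assumptions:

```python
from typing import Any, Dict

from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ParallelStyle, parallelize_module
from torchtune import config


def apply_tp_from_config(model: nn.Module, cfg_tp_plan: Any, tp_mesh: DeviceMesh) -> nn.Module:
    # cfg_tp_plan is the `tensor_parallel_plan` node from the YAML, e.g.
    #   tensor_parallel_plan:
    #     _component_: torchtune.training.base_llama_tp_plan  # hypothetical path
    tp_plan: Dict[str, ParallelStyle] = config.instantiate(cfg_tp_plan)
    return parallelize_module(model, tp_mesh, parallelize_plan=tp_plan)
```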
torchtune/training/_distributed.py
Outdated
    tp_mesh: DeviceMesh,
) -> nn.Module:
    """
    Adjusts the number of attention heads and dimension in the model to account for tensor parallelism.
This description is too vague. We should communicate what exactly is happening to the users.
torchtune/training/_distributed.py
Outdated
    """
    # Consider the case of Early Fusion or Deep Fusion models
    if isinstance(model, DeepFusionModel):
        model = model.docoder
-        model = model.docoder
+        model = model.decoder
i think that docoder sounds nicer
torchtune/training/_distributed.py
Outdated
assert attn.num_heads % tp_mesh.size() == 0
assert attn.num_kv_heads % tp_mesh.size() == 0
assert attn.embed_dim % tp_mesh.size() == 0
attn.num_heads = attn.num_heads // tp_mesh.size()
nit: Don't need to use the floor division operator if you already determined that the tp_mesh.size() goes evenly into the num_heads, etc.
it's expected to be an int, so I used //
Right, but both num_heads and tp_mesh.size() are ints so it'll always be an int.
It's not in my case :( I printed it out in the code; even though both the numerator and denominator are ints, the result is a float.
weird ... okay sounds good then!
torchtune/training/_distributed.py
Outdated
# Adjust attention module to use the local number of heads
attention_layers = ([layer.attn] if not isinstance(layer, FusionLayer) else [layer.fusion_layer.attn, layer.layer.attn])
for attn in attention_layers:
    assert attn.num_heads % tp_mesh.size() == 0
May as well pull the tp_mesh.size() call out to the top.
torchtune/training/_distributed.py
Outdated
    raise ValueError("TP is only supported for llama type models right now.")


def adjust_attention_for_tp(
Should we be more precise here? Something like scale_attention_heads_by_tp_size?
The only reason I could think not to do this is if there might be other adjustments to the attention we might need to do.
It changes the number of heads as well as the emb dim. Maybe shard_attention_params_for_tp?
Just reviewed and I personally still found shard_attention_params_for_tp a bit confusing. Really you are just setting num_heads so that the reshapes are TP-aware, right? If that's the case I would even lean towards calling it prepare_mha_for_tp or something (the current name makes me think you are actually distributing the params across devices in this utility, which you are not). Have I mentioned I hate naming things
tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)

# Get TP plan and apply TP
tp_plan = training.get_tp_plan(cfg.checkpointer.model_type)
I would not count on model_type staying around.
Agreed, we should find some other way to reliably get the TP plan depending on the model... maybe this could be a parameter in the config that points to a function?
Really like this, we just need to figure out how we expect users to use our built-in TP plans and if they want to supply their own. and how this can be specified from the config.
This *does not* currently support the following features:
    - torch.compile
    - quantization through torchao
    - multi-GPU generation
Should update this and call out that this can be run distributed for larger models using TP. And could point to some pytorch docs on TP.
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group("cuda:nccl,cpu:gloo")
I think we have a utility for this that also sets the port, etc?
I was following full_finetune_distributed.py and it seems like it doesn't use any util, just init_process_group().
with training.set_default_dtype(self._dtype), torch.device("meta"):
    model = config.instantiate(cfg.model)

# Set up tenosr parallel device mesh
-  # Set up tenosr parallel device mesh
+  # Set up tensor parallel device mesh
    model = config.instantiate(cfg.model)

# Set up tenosr parallel device mesh
tp_degree = dist.get_world_size()  # Using all GPUs for TP
will this also work on multinode?
Please don't make me worry about multi-node INFERENCE too
then we should prevent users from trying multinode somewhere?
Can drop a comment at the top of the recipe
Added into the docstring for the InferenceRecipe class.
@@ -433,6 +433,17 @@ class Recipe:
    ],
    supports_distributed=False,
),
Recipe(
you need to add the 3.1 and 3.2 configs here, no?
Added configs for 3, 3.1, and 3.2; for 3.2 I am still debugging.
# Define the parallelism plan for Llama3.2 vision model
LLAMA_DEEP_FUSION_VISION_TP_PLAN = {
I would just call this LLAMA_3_2_VISION_TP_PLAN
torchtune/training/_distributed.py
Outdated
for attn in attention_layers:
    assert attn.num_heads % tp_mesh.size() == 0
    assert attn.num_kv_heads % tp_mesh.size() == 0
    assert attn.embed_dim % tp_mesh.size() == 0
Use if ... raise instead. It is more descriptive and you can describe what needs to be changed to fix the error.
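A sketch of that change applied to the quoted asserts (error wording is illustrative):

```python
tp_size = tp_mesh.size()
for attn in attention_layers:
    # Fail with an actionable message instead of a bare assert
    if (
        attn.num_heads % tp_size != 0
        or attn.num_kv_heads % tp_size != 0
        or attn.embed_dim % tp_size != 0
    ):
        raise ValueError(
            f"num_heads ({attn.num_heads}), num_kv_heads ({attn.num_kv_heads}) and "
            f"embed_dim ({attn.embed_dim}) must all be divisible by the tensor parallel "
            f"size ({tp_size}); choose a TP degree that evenly divides these values."
        )
```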
Takes ~10 seconds, ~3 tokens/second, to run inference on a single prompt for llama3 and 3.1.
This reverts commit 3442bbe.
This is looking great! Really excited to see this landing in the library. Apart from inline comments, main questions are around testing -- it'd be great to update the PR summary to address the following questions.
- Are the generations at parity with the corresponding ones from generate_v2.py (obviously not feasible for 70B, but could test on e.g. an 8B model)?
- What's the throughput and peak memory? (On a related note, are 8 devices necessary to run 70B inference? If not, can maybe use a smaller number in the config.) Edit: I see you answered this in a comment above, still might be good to include comprehensive results in the PR test plan.
- Do the TP sharding utilities work with our other model families (Gemma, Mistral, Phi, etc.)? Any models that they definitely do not work with? If that's the case we don't have to block on supporting everything, just want to be explicit about that.
recipes/dev/generate_v2.py
Outdated
@@ -109,10 +113,10 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
    f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
-    f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
+    f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024:.02f} GB/s"
Maybe I'm missing something obvious, but shouldn't there be a 3rd 1024 here (and below)?
# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed.yaml
Should remove .yaml from these commands once the configs are added to the recipe registry.
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(output_layouts=Replicate()),
Curious about this comment in the torchchat code
Yeah I am curious how they got to this conclusion. For torchtune inference, I tried commenting these two lines out and the inference speed doesn't make too much difference (I ran it several times; sometimes it's slower and sometimes faster, all differing by 0.0x seconds).
Thanks, that's good to know. cc @fduwjj as the author of the corresponding torchchat PR in case you have any insights
}


def base_llama_tp_plan() -> Dict[str, Any]:
Can values in the return dict be typed as ParallelStyle? Or are there cases where the plan for a layer might fall outside these classes?
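A sketch of the tighter typing being asked about, assuming every value in the plan really is a ParallelStyle subclass (ColwiseParallel and RowwiseParallel both are):

```python
from typing import Dict

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)


def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
    """Return the default tensor parallel plan for llama-style decoders."""
    return {
        "layers.*.attn.q_proj": ColwiseParallel(),
        "layers.*.attn.output_proj": RowwiseParallel(),
        # ... remaining entries as in the plan above
    }
```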
"decoder.layers.*.layer.attn.k_proj": ColwiseParallel(),
"decoder.layers.*.layer.attn.v_proj": ColwiseParallel(),
"decoder.layers.*.layer.attn.output_proj": RowwiseParallel(),
"decoder.layers.*.layer.mlp.w1": ColwiseParallel(),
Can we consolidate e.g. decoder.layers.*.layer.mlp.w1 and decoder.layers.*.fusion_layer.mlp.w1 -> decoder.layers.*.mlp.w1 or something like that? (Similar question for other layers.)
Good catch, I think it's the same, the wildcard should support this. But let me test it on 3.2 later; right now I am having some distributed state dict problem with 3.2.
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
Do we actually wanna support this on CPU?
probably not, let me remove that.
    f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024:.02f} GB/s"
)
self._logger.info(
    f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024:.02f} GB"
Same comment here as in generate_v2.py
torchtune/training/_distributed.py
Outdated
attn.num_heads = attn.num_heads // tp_size
attn.num_kv_heads = attn.num_kv_heads // tp_size
attn.embed_dim = attn.embed_dim // tp_size
Just want to make sure I understand the purpose of this utility: the idea is to set the appropriate params on the attention module so that any reshapes etc performed there will result in the correctly-shaped input for sharded Q, K, V, output projections?
Also wonder whether we could just iterate over e.g. model.modules() and check isinstance(m, MultiHeadAttentionLayer) to avoid any dependency on the actual higher-level model arch details (fusion vs not, etc). For training this might be risky (e.g. AC will wrap modules and then you'd need to decide whether to apply on the AC-wrapped module vs not), but I don't see an obvious case it would break in inference.
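A minimal sketch of that alternative, assuming the relevant attention class is torchtune.modules.MultiHeadAttention and leaving the divisibility checks (see the if/raise discussion above) aside:

```python
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torchtune.modules import MultiHeadAttention


def prepare_mha_for_tp(model: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    """Scale attention hyperparameters by the TP size by walking all submodules,
    without assuming anything about fusion vs. plain decoder layouts."""
    tp_size = tp_mesh.size()
    for m in model.modules():
        if isinstance(m, MultiHeadAttention):
            m.num_heads = m.num_heads // tp_size
            m.num_kv_heads = m.num_kv_heads // tp_size
            m.embed_dim = m.embed_dim // tp_size
    return model
```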
torchtune/training/_distributed.py
Outdated
@@ -546,3 +550,71 @@ def shard_model(

    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)


def shard_attention_params_for_tp(
Awesome!
model:
  _component_: torchtune.models.llama3.llama3_70b

tensor_parallel_plan:
I think we might want to mirror the parallelize_module API and call this parallelize_plan. See https://pytorch.org/docs/main/distributed.tensor.parallel.html#torch.distributed.tensor.parallel.parallelize_module
Incredible work!
@@ -546,3 +548,66 @@ def shard_model(

    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)


def prepare_mha_for_tp(
This should probably have a test :)
added! thanks for the reminder.
recipes/dev/generate_v2.py
Outdated
@@ -109,10 +113,10 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
    f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
-    f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
+    f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024 / 1024:.02f} GiB/s"
nit nit nit: 1024 ** 3
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
- Enabled TP for inference for llama3, 3.1, and 3.3 70B text-only models. Llama3.2 vision 90B is still work in progress; it's blocked by #2277 and will be enabled in a follow-up PR.
- Added a distributed recipe based on dev/generate_v2.py with TP added. The main change is in __init__ and __setup__.
- Added the TP plan in a _parallelism.py file.
- Added TP utilities in _distributed.py; note that the utilities are for now only designed to work with llama models.
- Extended load_from_full_model_state_dict to general parallelism, not only FSDP.
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- pre-commit install
- pytest tests
- pytest tests -m integration_test
We see that the max memory allocated is only 18.7 GiB, which indicates we may be able to use fewer GPUs.
Running on 2 GPUs, we see that the max memory increases.
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.