Added Distributed (Tensor Parallel) Inference Recipe #2245

Merged: 34 commits, merged Jan 18, 2025. The file changes shown below are from 27 of these commits.

Commits (34)
b8c96a8  update cuda version (jessicazhongeee, Jan 9, 2025)
9fd582e  add distributed inference (jessicazhongeee, Jan 10, 2025)
975ec47  Merge branch 'pytorch:main' into distributed_inference (acisseJZhong, Jan 10, 2025)
240d454  recover generation config (jessicazhongeee, Jan 10, 2025)
f7615a1  Merge branch 'distributed_inference' of ssh://github.com/acisseJZhong… (jessicazhongeee, Jan 10, 2025)
f746048  added configs (jessicazhongeee, Jan 10, 2025)
a318544  remove 3.2 vision generation config (jessicazhongeee, Jan 10, 2025)
832d60a  formatting (jessicazhongeee, Jan 11, 2025)
7304ea5  formatting (jessicazhongeee, Jan 15, 2025)
fcb36a5  remove imports (jessicazhongeee, Jan 15, 2025)
3f2d6ce  misc (jessicazhongeee, Jan 15, 2025)
41941a9  trying to add vision3.2 (jessicazhongeee, Jan 16, 2025)
8ff6c95  address comments (jessicazhongeee, Jan 16, 2025)
1c7b394  misc (jessicazhongeee, Jan 16, 2025)
5fd02e6  address comments (jessicazhongeee, Jan 16, 2025)
345b350  addressed comments (jessicazhongeee, Jan 16, 2025)
129c844  delete unused function (jessicazhongeee, Jan 16, 2025)
19100d2  misc (jessicazhongeee, Jan 16, 2025)
3442bbe  debugging (jessicazhongeee, Jan 17, 2025)
63f0423  Revert "debugging" (jessicazhongeee, Jan 17, 2025)
04c18b3  add llama3.3 config (jessicazhongeee, Jan 17, 2025)
835cbc7  address comments (jessicazhongeee, Jan 17, 2025)
cc7ece5  debugging (jessicazhongeee, Jan 17, 2025)
a929e66  address comments (jessicazhongeee, Jan 17, 2025)
1bc1b4d  remove 3.2 vision (jessicazhongeee, Jan 17, 2025)
5ad117b  formatting (jessicazhongeee, Jan 17, 2025)
68aee31  added recipes for registry (jessicazhongeee, Jan 17, 2025)
a80b7e5  misc (jessicazhongeee, Jan 17, 2025)
510944e  merge main (jessicazhongeee, Jan 17, 2025)
f14655d  merge main (jessicazhongeee, Jan 17, 2025)
5b36960  Merge branch 'main' into distributed_inference (acisseJZhong, Jan 17, 2025)
7f37b6b  formatting (jessicazhongeee, Jan 17, 2025)
9db97a0  add tests (jessicazhongeee, Jan 18, 2025)
1ad2f76  formatting (jessicazhongeee, Jan 18, 2025)
7 changes: 6 additions & 1 deletion recipes/configs/generation.yaml
@@ -1,4 +1,9 @@
-# Config for running the InferenceRecipe in generate.py to generate output from an LLM
+# Config for running the InferenceRecipe in generate.py to generate output
+# from Llama2 7B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --ignore-patterns "*.safetensors" --hf-token <HF_TOKEN>
 #
 # To launch, run the following command from root torchtune directory:
 # tune run generate --config generation
50 changes: 50 additions & 0 deletions recipes/configs/llama3/70B_generation_distributed.yaml
@@ -0,0 +1,50 @@
# Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output
# using a Llama3 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
Contributor: should we be adding a config for llama 3.3?

#
# To launch, run the following command from root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed

output_dir: ./

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_70b

parallelize_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model
  prompt_template: null
  max_seq_len: 8192

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: null
  user:
    text: Tell a joke.
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
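
For context on the parallelize_plan entry: a tensor-parallel plan of this kind is essentially a mapping from submodule names to PyTorch ParallelStyle objects, applied with torch.distributed.tensor.parallel.parallelize_module across the ranks launched by tune run --nproc_per_node 8. The sketch below shows that general shape under stated assumptions; the module names and styles are illustrative, and the real plan is whatever torchtune.models.llama3.base_llama_tp_plan defines.

# Sketch of applying a tensor-parallel plan; module names are illustrative,
# not necessarily what base_llama_tp_plan actually contains.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Attention/MLP projections are sharded column-wise going in, row-wise coming out.
LAYER_PLAN = {
    "attn.q_proj": ColwiseParallel(),
    "attn.k_proj": ColwiseParallel(),
    "attn.v_proj": ColwiseParallel(),
    "attn.output_proj": RowwiseParallel(),
    "mlp.w1": ColwiseParallel(),
    "mlp.w3": ColwiseParallel(),
    "mlp.w2": RowwiseParallel(),
}

def apply_tp(model: torch.nn.Module, tp_size: int) -> torch.nn.Module:
    # Assumes the default process group is already initialized, e.g. by
    # `tune run --nproc_per_node 8 ...` or torchrun.
    mesh = init_device_mesh("cuda", (tp_size,))
    for layer in model.layers:
        parallelize_module(layer, mesh, LAYER_PLAN)
    return model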
50 changes: 50 additions & 0 deletions recipes/configs/llama3_1/70B_generation_distributed.yaml
@@ -0,0 +1,50 @@
# Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output
# using a Llama3.1 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_1/70B_generation_distributed

output_dir: ./

# Model arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_70b

parallelize_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model
  prompt_template: null
  max_seq_len: 8192

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: null
  user:
    text: Tell a joke.
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
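
A note on the checkpoint_files block shared by these configs: filename_format plus max_filename describes the standard Hugging Face shard naming, i.e. model-00001-of-00030.safetensors through model-00030-of-00030.safetensors. A small sketch of that expansion follows, assuming the index is zero-padded to the width of max_filename; expand_checkpoint_files is a hypothetical helper, not torchtune's implementation.

# Hypothetical helper showing how the shard list implied by the config expands.
def expand_checkpoint_files(filename_format: str, max_filename: str) -> list:
    n, width = int(max_filename), len(max_filename)
    return [
        filename_format.format(str(i).zfill(width), max_filename)
        for i in range(1, n + 1)
    ]

files = expand_checkpoint_files("model-{}-of-{}.safetensors", "00030")
print(files[0], files[-1])
# model-00001-of-00030.safetensors model-00030-of-00030.safetensors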
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2_vision/11B_generation_v2.yaml
@@ -7,7 +7,7 @@
# To launch, run the following command from root torchtune directory:
# tune run dev/generate_v2 --config llama3_2_vision/generation_v2

-output_dir: ./ # Not needed
+output_dir: ./

# Model arguments
model:
50 changes: 50 additions & 0 deletions recipes/configs/llama3_3/70B_generation_distributed.yaml
@@ -0,0 +1,50 @@
# Config for running the InferenceRecipe in dev/generate_v2_distributed.py to generate output
# using a Llama3.3 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_3/70B_generation_distributed

output_dir: ./

# Model arguments
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b

parallelize_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
  prompt_template: null
  max_seq_len: 8192

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: null
  user:
    text: Tell a joke.
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
18 changes: 11 additions & 7 deletions recipes/dev/generate_v2.py
@@ -39,18 +39,22 @@ def __call__(self, prompt: Dict[str, Any]) -> List[Message]:

         # Iterate through roles and add content
         for role, content in prompt.items():
-            if isinstance(content, str):
+            if content is None:
+                continue
Comment on lines +42 to +43

Contributor: can it ever be None?

Contributor Author: When I put something like

    system: null
    user:
        text: Tell me a joke.

the system prompt could potentially be None, so we should probably handle it?

+            elif isinstance(content, str):
                 new_content = [{"type": "text", "content": content}]
-            else:
-                assert (
-                    "image" in content.keys()
-                ), "Multiple entries per role expect an image key"
+            elif "image" in content.keys():
                 image_loc = content["image"]
                 image = load_image(image_loc)
                 new_content = [
                     {"type": "image", "content": image},
                     {"type": "text", "content": content["text"]},
                 ]
+            else:
+                assert (
+                    "text" in content.keys()
+                ), "Multiple entries per role expect at least a text key"
+                new_content = [{"type": "text", "content": content["text"]}]
             messages.append(Message(role=role, content=new_content))

# Finally, add an empty assistant message to kick-start generation
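
To make the "system: null" discussion above concrete, here is a minimal, self-contained sketch of the branch logic in this hunk. Plain dicts stand in for torchtune's Message and load_image, and yaml_prompt_to_messages is a made-up name used only for illustration, not the recipe's API.

# Sketch only: mirrors the if/elif/else branches above with plain dicts.
from typing import Any, Dict, List

def yaml_prompt_to_messages(prompt: Dict[str, Any]) -> List[Dict[str, Any]]:
    messages = []
    for role, content in prompt.items():
        if content is None:  # e.g. "system: null" in the config
            continue
        elif isinstance(content, str):
            new_content = [{"type": "text", "content": content}]
        elif "image" in content:
            new_content = [
                {"type": "image", "content": content["image"]},
                {"type": "text", "content": content["text"]},
            ]
        else:
            assert "text" in content, "Multiple entries per role expect at least a text key"
            new_content = [{"type": "text", "content": content["text"]}]
        messages.append({"role": role, "content": new_content})
    return messages

# The prompt from the review thread: the None-valued system role is skipped.
print(yaml_prompt_to_messages({"system": None, "user": {"text": "Tell me a joke."}}))
# [{'role': 'user', 'content': [{'type': 'text', 'content': 'Tell me a joke.'}]}]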
@@ -109,10 +113,10 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024 / 1024:.02f} GiB/s"
Contributor: nit nit nit: 1024 ** 3

         )
         self._logger.info(
-            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
+            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024:.02f} GiB"
         )

@torch.inference_mode()
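
On the 1024 ** 3 nit above: a quick sketch contrasting the old 1e9 (decimal GB) scaling with the new 1024 ** 3 (binary GiB) scaling. The model_size value is an illustrative assumption (roughly a 70B parameter model in bf16), not a value taken from the recipe.

# Illustrative numbers only: compare GB (1e9) vs GiB (1024 ** 3) scaling.
model_size = 140 * 10**9      # assumed ~140e9 bytes for a 70B bf16 model
tokens_per_second = 20.0

gb_per_s = model_size * tokens_per_second / 1e9        # decimal gigabytes
gib_per_s = model_size * tokens_per_second / 1024**3   # binary gibibytes

print(f"Bandwidth achieved: {gb_per_s:.02f} GB/s")     # 2800.00 GB/s
print(f"Bandwidth achieved: {gib_per_s:.02f} GiB/s")   # ~2607.70 GiB/s, about 7% lower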