Added Distributed (Tensor Parallel) Inference Recipe #2245
```yaml
@@ -0,0 +1,50 @@
# Config for running the InferenceRecipe in dev/generate_v2.py to generate output
# using a Llama3 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from the root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3/70B_generation_distributed

output_dir: ./

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_70b

parallelize_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model
  prompt_template: null
  max_seq_len: 8192

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: null
  user:
    text: Tell a joke.
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
```
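
For readers new to tensor parallelism: `parallelize_plan` is a mapping from module FQNs to DTensor parallel styles that the recipe applies across the process group. Below is a minimal sketch, assuming the standard PyTorch `parallelize_module` API; the FQN keys are illustrative rather than torchtune's exact `base_llama_tp_plan`:

```python
# Sketch: how a tensor-parallel plan is applied with PyTorch DTensor.
# Assumes launch under torchrun / `tune run` so distributed is initialized.
# The FQN keys below are illustrative, not torchtune's actual plan.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_tp(model: nn.Module, tp_size: int = 8) -> nn.Module:
    """Shard attention/MLP projections across `tp_size` GPUs."""
    mesh = init_device_mesh("cuda", (tp_size,))  # matches --nproc_per_node 8
    plan = {
        # Column-wise styles shard the output dim; row-wise shard the input
        # dim, so a col->row pair needs only one all-reduce per block.
        "layers.*.attn.q_proj": ColwiseParallel(),
        "layers.*.attn.k_proj": ColwiseParallel(),
        "layers.*.attn.v_proj": ColwiseParallel(),
        "layers.*.attn.output_proj": RowwiseParallel(),
        "layers.*.mlp.w1": ColwiseParallel(),
        "layers.*.mlp.w2": RowwiseParallel(),
        "layers.*.mlp.w3": ColwiseParallel(),
    }
    return parallelize_module(model, mesh, plan)
```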
```yaml
@@ -0,0 +1,50 @@
# Config for running the InferenceRecipe in dev/generate_v2.py to generate output
# using a Llama3.1 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from the root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_1/70B_generation_distributed

output_dir: ./

# Model arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_70b

parallelize_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model
  prompt_template: null
  max_seq_len: 8192

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: null
  user:
    text: Tell a joke.
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
```
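
Note how `checkpoint_files` is given as a `filename_format` plus `max_filename` rather than an explicit list of 30 shards. A small sketch of the expansion this implies (`expand_checkpoint_files` is a hypothetical helper for illustration; torchtune's checkpointer does this internally):

```python
# Sketch of how filename_format/max_filename expand into concrete shard names.
# Hypothetical helper for illustration only.
def expand_checkpoint_files(filename_format: str, max_filename: str) -> list[str]:
    n = int(max_filename)          # number of shards, e.g. 30
    width = len(max_filename)      # zero-padding width, e.g. 5
    return [
        filename_format.format(f"{i:0{width}d}", max_filename)
        for i in range(1, n + 1)
    ]

files = expand_checkpoint_files("model-{}-of-{}.safetensors", "00030")
assert files[0] == "model-00001-of-00030.safetensors"
assert files[-1] == "model-00030-of-00030.safetensors"
```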
```yaml
@@ -0,0 +1,50 @@
# Config for running the InferenceRecipe in dev/generate_v2.py to generate output
# using a Llama3.3 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --hf-token <HF_TOKEN>
#
# To launch, run the following command from the root torchtune directory:
# tune run --nproc_per_node 8 dev/generate_v2_distributed --config llama3_3/70B_generation_distributed

output_dir: ./

# Model arguments
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b

parallelize_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan

# Transform arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
  prompt_template: null
  max_seq_len: 8192

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
  system: null
  user:
    text: Tell a joke.
max_new_tokens: 200
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
```
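
The generation arguments correspond to standard temperature scaling followed by top-k filtering. A minimal sketch of that sampling step, under the values configured above (illustrative, not the recipe's exact implementation):

```python
# Minimal temperature + top-k sampling step, matching temperature=0.6, top_k=300
# from the config. Illustrative; torchtune's generation utilities do the real work.
import torch

def sample_next_token(
    logits: torch.Tensor, temperature: float = 0.6, top_k: int = 300
) -> torch.Tensor:
    """logits: [batch, vocab_size] for the last position."""
    logits = logits / max(temperature, 1e-5)         # <1.0 sharpens the distribution
    topk_vals, topk_idx = torch.topk(logits, top_k)  # keep only the k most likely tokens
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1) # sample within the top-k
    return topk_idx.gather(-1, choice)               # map back to vocabulary ids
```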
```diff
@@ -39,18 +39,22 @@ def __call__(self, prompt: Dict[str, Any]) -> List[Message]:
 
         # Iterate through roles and add content
         for role, content in prompt.items():
-            if isinstance(content, str):
+            if content is None:
+                continue
+            elif isinstance(content, str):
                 new_content = [{"type": "text", "content": content}]
-            else:
-                assert (
-                    "image" in content.keys()
-                ), "Multiple entries per role expect an image key"
+            elif "image" in content.keys():
                 image_loc = content["image"]
                 image = load_image(image_loc)
                 new_content = [
                     {"type": "image", "content": image},
                     {"type": "text", "content": content["text"]},
                 ]
+            else:
+                assert (
+                    "text" in content.keys()
+                ), "Multiple entries per role expect at least a text key"
+                new_content = [{"type": "text", "content": content["text"]}]
             messages.append(Message(role=role, content=new_content))
 
         # Finally, add an empty assistant message to kick-start generation
```

Comment on lines +42 to +43
Reviewer: can it ever be None?
Reply: When I put something like `system: null` in the config, the system prompt could potentially be None, so we should probably handle it?
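
To make the new branching concrete, here are the shapes of `prompt` input the transform now accepts (illustrative values, mirroring the `prompt` section of the configs above):

```python
# Skipped entirely now instead of falling into the image branch:
prompt = {
    "system": None,                    # e.g. from `system: null` in the config
    "user": {"text": "Tell a joke."},  # handled by the new text-only dict branch
}

# A plain string per role still works:
prompt = {"user": "Tell a joke."}

# And a multimodal entry with an image key (path is illustrative):
prompt = {"user": {"image": "cat.png", "text": "Describe this image."}}
```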
```diff
@@ -109,10 +113,10 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
             f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
         )
         self._logger.info(
-            f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
+            f"Bandwidth achieved: {model_size * tokens_per_second / 1024 / 1024 / 1024:.02f} GiB/s"
         )
         self._logger.info(
-            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
+            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024:.02f} GiB"
         )
 
     @torch.inference_mode()
```

Comment: nit nit nit: `1024 ** 3`
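
The unit change is small but real: a GiB is 1024**3 bytes versus 1e9 for a GB, about a 7.4% difference. A quick check (the 140 GB figure assumes bf16 weights for a 70B-parameter model):

```python
# GiB vs GB for the logged bandwidth/memory numbers.
GIB = 1024 ** 3          # the reviewer's suggested constant
bytes_moved = 140e9      # assumed: ~70B params * 2 bytes (bf16)
print(f"{bytes_moved / 1e9:.02f} GB == {bytes_moved / GIB:.02f} GiB")
# 140.00 GB == 130.39 GiB
```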
Comment: should we be adding a config for llama 3.3?