
Commit

#0: Address review
mtairum committed Jan 20, 2025
1 parent 1967a64 commit 5dfa882
Showing 2 changed files with 41 additions and 62 deletions.
75 changes: 16 additions & 59 deletions models/demos/llama3/tt/generator.py
@@ -25,8 +25,6 @@
     get_max_prefill_chunk_size,
 )
 
-from time import time
-
 
 class LlamaGenerator:
     def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=None):
@@ -216,28 +214,9 @@ def _decode_forward_no_trace_text(
             rot_mats=tt_rot_mats,
             page_table=tt_page_table,
             kv_cache=kv_cache,
+            argmax_on_device=argmax_on_device,
         )
 
-        # Gather the output across all devices and untilize the tensor (for argmax)
-        if self.model.args.num_devices > 1:
-            if self.model.args.is_galaxy:
-                tt_logits = ttnn.all_gather(
-                    tt_logits,
-                    dim=3,
-                    num_links=2,
-                    cluster_axis=0,
-                    mesh_device=self.model.mesh_device,
-                    topology=ttnn.Topology.Linear,
-                )
-            else:
-                tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
-            tt_logits = ttnn.untilize(tt_logits, use_multicore=True)
-
-        if argmax_on_device:
-            tt_logits = ttnn.argmax(  # TODO Add multicore support to batch > 1
-                tt_logits, dim=3, use_multicore=False if tokens.shape[0] > 1 else True  # ,output_tensor=tokens
-            )
-
         return tt_logits
 
     def _capture_trace_text(
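Note: the generator-side gather and argmax are deleted here because ttnn_decode_forward (second file below) now performs them, so this helper reduces to a single model call. A minimal sketch of the resulting shape of the method, with illustrative names for the prepared inputs:

def _decode_forward_no_trace_text_sketch(self, tt_tokens, tt_current_pos, tt_rot_mats,
                                         tt_page_table=None, kv_cache=None,
                                         argmax_on_device=False):
    # Sketch only: cross-device gathering, untilizing, and optional argmax
    # now happen inside the model, so trace and no-trace paths stay identical.
    return self.model.ttnn_decode_forward(
        tt_tokens,
        tt_current_pos,
        rot_mats=tt_rot_mats,
        page_table=tt_page_table,
        kv_cache=kv_cache,
        argmax_on_device=argmax_on_device,
    )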
@@ -265,26 +244,9 @@ def _capture_trace_text(
 
         trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
         transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs)
-        tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache)
-
-        if self.model.args.num_devices > 1:
-            if self.model.args.is_galaxy:
-                tt_out_trace = ttnn.all_gather(
-                    tt_out_trace,
-                    dim=3,
-                    num_links=2,
-                    cluster_axis=0,
-                    mesh_device=self.model.mesh_device,
-                    topology=ttnn.Topology.Linear,
-                )
-            else:
-                tt_out_trace = ttnn.all_gather(tt_out_trace, dim=3, num_links=1, topology=ttnn.Topology.Linear)
-            tt_out_trace = ttnn.untilize(tt_out_trace, use_multicore=True)
-
-        if argmax_on_device:
-            tt_out_trace = ttnn.argmax(  # TODO Add multicore support to batch > 1
-                tt_out_trace, dim=3, use_multicore=False if tokens.shape[0] > 1 else True  # , output_tensor=tokens
-            )
+        tt_out_trace = self.model.ttnn_decode_forward(
+            *transformed_inputs, kv_cache=kv_cache, argmax_on_device=argmax_on_device
+        )
 
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
         logger.info("Done Capturing Decode Trace")
@@ -474,6 +436,7 @@ def decode_forward(
         cross_page_table=None,
         enable_trace=True,
         read_from_device=True,
+        argmax_on_device=False,
     ):
         decode_kwargs = {
             "position_id": start_pos,
@@ -484,6 +447,7 @@
             "page_table": page_table,
             "kv_cache": kv_cache,
             "cross_page_table": cross_page_table,
+            "argmax_on_device": argmax_on_device,
         }
         if enable_trace:
             tt_logits = self._easy_trace(**decode_kwargs)
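Note: functionally, the new flag moves greedy token selection onto the device. For comparison, a host-side equivalent of what argmax_on_device=True avoids, as a small self-contained PyTorch sketch (shapes illustrative):

import torch

def host_greedy(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq=1, vocab] -> token ids: [batch, 1]; this readback
    # and reduction is what the device-side argmax path skips.
    return logits.argmax(dim=-1)

logits = torch.randn(8, 1, 128256)  # Llama 3 vocabulary size, batch of 8
next_tokens = host_greedy(logits)   # shape [8, 1]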
@@ -509,6 +473,7 @@ def _decode_forward_no_trace(
         page_table=None,
         kv_cache=None,
         cross_page_table=None,
+        argmax_on_device=False,
     ):
         """
         Performs text decode step.
@@ -547,22 +512,9 @@ def _decode_forward_no_trace(
             page_table=tt_page_table,
             kv_cache=kv_cache,
             cross_page_table=tt_cross_page_table,
+            argmax_on_device=argmax_on_device,
         )
 
-        if self.model.args.num_devices > 1:
-            if self.model.args.is_galaxy:
-                tt_logits = ttnn.all_gather(
-                    tt_logits,
-                    dim=3,
-                    num_links=2,
-                    cluster_axis=0,
-                    mesh_device=self.model.mesh_device,
-                    topology=ttnn.Topology.Linear,
-                )
-            else:
-                tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
-            tt_logits = ttnn.untilize(tt_logits, use_multicore=True)
-
         return tt_logits
 
     def _capture_trace(
@@ -575,6 +527,7 @@ def _capture_trace(
         page_table=None,
         kv_cache=None,
         cross_page_table=None,
+        argmax_on_device=False,
     ):
         """
         Captures a trace for the decode_forward method.
Expand Down Expand Up @@ -607,6 +560,7 @@ def _capture_trace(
page_table=tt_page_table,
kv_cache=kv_cache,
cross_page_table=tt_cross_page_table,
argmax_on_device=argmax_on_device,
)
logger.info("Done Compiling Model")

@@ -676,6 +630,7 @@ def _capture_trace(
             page_table=tt_page_table,
             kv_cache=kv_cache,
             cross_page_table=tt_cross_page_table,
+            argmax_on_device=argmax_on_device,
         )
 
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
@@ -752,7 +707,7 @@ def _decode_forward_trace(
             ),
         )
 
-        ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=True)
+        ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)
 
         return trace_logits_rm
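Note: switching execute_trace to blocking=False queues the replay and returns immediately, letting host-side input preparation overlap device execution; the later device-to-host read of the output becomes the synchronization point. The change in behavior, sketched:

# Queue the captured graph; the call returns once enqueued (the previous
# blocking=True made the host wait for device completion here).
ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)
# Host work for the next step can now overlap device execution; reading
# trace_logits_rm back to host later is what forces synchronization.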

@@ -766,6 +721,7 @@ def _easy_trace(
         page_table=None,
         kv_cache=None,
         cross_page_table=None,
+        argmax_on_device=False,
     ):
         """
         Tracing is easy! Just call this method and we'll handle tracing for you.
@@ -790,6 +746,7 @@
                 page_table=page_table,
                 kv_cache=kv_cache,
                 cross_page_table=cross_page_table,
+                argmax_on_device=argmax_on_device,
             )
             self.trace_id = trace_id
             self.trace_inputs = {
@@ -901,8 +858,8 @@ def chat_completion(
         top_p: float = 0.9,
         max_gen_len=None,
     ):
-        # if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len:
-        #     max_gen_len = self.model.configuration.max_seq_len - 1
+        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len:
+            max_gen_len = self.model.configuration.max_seq_len - 1
 
         tokens = []
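Note: re-enabling this clamp restores the guard on generation length: a missing, zero, or oversized max_gen_len falls back to one less than the model's context window, so prompt plus generation never exceed max_seq_len. A standalone illustration (the max_seq_len value is just an example):

max_seq_len = 131072  # example context window; the real value comes from
                      # self.model.configuration.max_seq_len

def clamp_gen_len(max_gen_len):
    if max_gen_len is None or max_gen_len == 0 or max_gen_len >= max_seq_len:
        return max_seq_len - 1
    return max_gen_len

assert clamp_gen_len(None) == 131071
assert clamp_gen_len(0) == 131071
assert clamp_gen_len(200_000) == 131071
assert clamp_gen_len(512) == 512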

28 changes: 25 additions & 3 deletions models/demos/llama3/tt/llama_model.py
@@ -248,8 +248,7 @@ def process_output_decode(self, tt_out, B, S=1, argmax_on_device=False):
         """
         if argmax_on_device:
             tt_out = ttnn.to_torch(
-                # tt_out.cpu(blocking=True, cq_id=1),
-                tt_out,
+                tt_out,  # tt_out.cpu(blocking=True, cq_id=1),
                 mesh_composer=ttnn.ConcatMesh2dToTensor(
                     self.mesh_device,
                     dims=(3, 1) if self.args.is_galaxy else (1, -1),
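Note: the swap above keeps the same tensor readback and preserves the cpu(blocking=True, cq_id=1) alternative inline as a comment. For orientation, a sketch of reading a mesh-sharded decode output back to torch; the mesh_shape argument and its value are assumptions for illustration, not taken from this file:

# Hypothetical readback of a decode output sharded across a 2D device mesh.
torch_out = ttnn.to_torch(
    tt_out,
    mesh_composer=ttnn.ConcatMesh2dToTensor(
        mesh_device,
        dims=(3, 1),        # galaxy case from the diff: vocab on axis 3, batch on 1
        mesh_shape=(4, 8),  # assumed mesh geometry, purely illustrative
    ),
)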
@@ -300,12 +299,13 @@ def ttnn_decode_forward(
         rot_mats,
         page_table=None,
         kv_cache=None,
+        argmax_on_device=False,
     ):
         """
         This method will take device tensors and any other args to run forward.
         It returns ttnn device tensors.
         """
-        return self.forward(
+        tt_logits = self.forward(
             x,
             current_pos,
             rot_mats=rot_mats,
@@ -314,6 +314,28 @@
             kv_cache=kv_cache,
         )
 
+        # Gather the output across all devices and untilize the tensor (for argmax)
+        if self.args.num_devices > 1:
+            if self.args.is_galaxy:
+                tt_logits = ttnn.all_gather(
+                    tt_logits,
+                    dim=3,
+                    num_links=2,
+                    cluster_axis=0,
+                    mesh_device=self.mesh_device,
+                    topology=ttnn.Topology.Linear,
+                )
+            else:
+                tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
+            tt_logits = ttnn.untilize(tt_logits, use_multicore=True)
+
+        if argmax_on_device:
+            tt_logits = ttnn.argmax(  # TODO Add multicore support to batch > 1
+                tt_logits, dim=3, use_multicore=False if self.args.max_batch_size > 1 else True  # ,output_tensor=tokens
+            )
+
+        return tt_logits
 
     def forward(
         self,
         x: ttnn.Tensor,
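Note: with the gather/argmax tail shown above, a device-resident greedy decode step becomes the pair of calls below; the signatures match the hunks in this file, while tensor names and the batch size are illustrative:

# Greedy decode step with on-device argmax (sketch; x, current_pos,
# rot_mats, page_table, and kv_cache are prepared device inputs).
tt_out = model.ttnn_decode_forward(
    x,
    current_pos,
    rot_mats,
    page_table=page_table,
    kv_cache=kv_cache,
    argmax_on_device=True,  # returns selected token ids, not vocab logits
)
# With argmax_on_device=True, process_output_decode just reads back the
# already-selected tokens instead of a full logits tensor.
next_tokens = model.process_output_decode(tt_out, B=1, argmax_on_device=True)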
