From 5dfa882bfc70978ca0be067f0b6527d50a85f823 Mon Sep 17 00:00:00 2001
From: mtairum
Date: Mon, 20 Jan 2025 12:45:43 +0000
Subject: [PATCH] #0: Address review

---
 models/demos/llama3/tt/generator.py   | 75 ++++++---------------------
 models/demos/llama3/tt/llama_model.py | 28 ++++++++--
 2 files changed, 41 insertions(+), 62 deletions(-)

diff --git a/models/demos/llama3/tt/generator.py b/models/demos/llama3/tt/generator.py
index 5a31f9c4e20b..0a9857d7240a 100644
--- a/models/demos/llama3/tt/generator.py
+++ b/models/demos/llama3/tt/generator.py
@@ -25,8 +25,6 @@
     get_max_prefill_chunk_size,
 )

-from time import time
-

 class LlamaGenerator:
     def __init__(self, model, model_args, mesh_device, tokenizer=None, formatter=None):
@@ -216,28 +214,9 @@ def _decode_forward_no_trace_text(
             rot_mats=tt_rot_mats,
             page_table=tt_page_table,
             kv_cache=kv_cache,
+            argmax_on_device=argmax_on_device,
         )

-        # Gather the output across all devices and untilize the tensor (for argmax)
-        if self.model.args.num_devices > 1:
-            if self.model.args.is_galaxy:
-                tt_logits = ttnn.all_gather(
-                    tt_logits,
-                    dim=3,
-                    num_links=2,
-                    cluster_axis=0,
-                    mesh_device=self.model.mesh_device,
-                    topology=ttnn.Topology.Linear,
-                )
-            else:
-                tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
-            tt_logits = ttnn.untilize(tt_logits, use_multicore=True)
-
-        if argmax_on_device:
-            tt_logits = ttnn.argmax(  # TODO Add multicore support to batch > 1
-                tt_logits, dim=3, use_multicore=False if tokens.shape[0] > 1 else True  # ,output_tensor=tokens
-            )
-
         return tt_logits

     def _capture_trace_text(
@@ -265,26 +244,9 @@ def _capture_trace_text(

         trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
         transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs)
-        tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache)
-
-        if self.model.args.num_devices > 1:
-            if self.model.args.is_galaxy:
-                tt_out_trace = ttnn.all_gather(
-                    tt_out_trace,
-                    dim=3,
-                    num_links=2,
-                    cluster_axis=0,
-                    mesh_device=self.model.mesh_device,
-                    topology=ttnn.Topology.Linear,
-                )
-            else:
-                tt_out_trace = ttnn.all_gather(tt_out_trace, dim=3, num_links=1, topology=ttnn.Topology.Linear)
-            tt_out_trace = ttnn.untilize(tt_out_trace, use_multicore=True)
-
-        if argmax_on_device:
-            tt_out_trace = ttnn.argmax(  # TODO Add multicore support to batch > 1
-                tt_out_trace, dim=3, use_multicore=False if tokens.shape[0] > 1 else True  # , output_tensor=tokens
-            )
+        tt_out_trace = self.model.ttnn_decode_forward(
+            *transformed_inputs, kv_cache=kv_cache, argmax_on_device=argmax_on_device
+        )
         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)

         logger.info("Done Capturing Decode Trace")
@@ -474,6 +436,7 @@ def decode_forward(
         cross_page_table=None,
         enable_trace=True,
         read_from_device=True,
+        argmax_on_device=False,
     ):
         decode_kwargs = {
             "position_id": start_pos,
@@ -484,6 +447,7 @@
             "page_table": page_table,
             "kv_cache": kv_cache,
             "cross_page_table": cross_page_table,
+            "argmax_on_device": argmax_on_device,
         }
         if enable_trace:
             tt_logits = self._easy_trace(**decode_kwargs)
@@ -509,6 +473,7 @@ def _decode_forward_no_trace(
         page_table=None,
         kv_cache=None,
         cross_page_table=None,
+        argmax_on_device=False,
     ):
         """
         Performs text decode step.
@@ -547,22 +512,9 @@ def _decode_forward_no_trace(
             page_table=tt_page_table,
             kv_cache=kv_cache,
             cross_page_table=tt_cross_page_table,
+            argmax_on_device=argmax_on_device,
         )

-        if self.model.args.num_devices > 1:
-            if self.model.args.is_galaxy:
-                tt_logits = ttnn.all_gather(
-                    tt_logits,
-                    dim=3,
-                    num_links=2,
-                    cluster_axis=0,
-                    mesh_device=self.model.mesh_device,
-                    topology=ttnn.Topology.Linear,
-                )
-            else:
-                tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
-            tt_logits = ttnn.untilize(tt_logits, use_multicore=True)
-
         return tt_logits

     def _capture_trace(
@@ -575,6 +527,7 @@
         page_table=None,
         kv_cache=None,
         cross_page_table=None,
+        argmax_on_device=False,
     ):
         """
         Captures a trace for the decode_forward method.
@@ -607,6 +560,7 @@
             page_table=tt_page_table,
             kv_cache=kv_cache,
             cross_page_table=tt_cross_page_table,
+            argmax_on_device=argmax_on_device,
         )

         logger.info("Done Compiling Model")
@@ -676,6 +630,7 @@
             page_table=tt_page_table,
             kv_cache=kv_cache,
             cross_page_table=tt_cross_page_table,
+            argmax_on_device=argmax_on_device,
         )

         ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
@@ -752,7 +707,7 @@ def _decode_forward_trace(
             ),
         )

-        ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=True)
+        ttnn.execute_trace(self.mesh_device, trace_id, cq_id=0, blocking=False)

         return trace_logits_rm

@@ -766,6 +721,7 @@ def _easy_trace(
         page_table=None,
         kv_cache=None,
         cross_page_table=None,
+        argmax_on_device=False,
     ):
         """
         Tracing is easy! Just call this method and we'll handle tracing for you.
@@ -790,6 +746,7 @@
                 page_table=page_table,
                 kv_cache=kv_cache,
                 cross_page_table=cross_page_table,
+                argmax_on_device=argmax_on_device,
             )
             self.trace_id = trace_id
             self.trace_inputs = {
@@ -901,8 +858,8 @@ def chat_completion(
         top_p: float = 0.9,
         max_gen_len=None,
     ):
-        # if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len:
-        #     max_gen_len = self.model.configuration.max_seq_len - 1
+        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.configuration.max_seq_len:
+            max_gen_len = self.model.configuration.max_seq_len - 1

         tokens = []

diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py
index 312e29e54bc1..670a90c3fde1 100644
--- a/models/demos/llama3/tt/llama_model.py
+++ b/models/demos/llama3/tt/llama_model.py
@@ -248,8 +248,7 @@ def process_output_decode(self, tt_out, B, S=1, argmax_on_device=False):
         """
         if argmax_on_device:
             tt_out = ttnn.to_torch(
-                # tt_out.cpu(blocking=True, cq_id=1),
-                tt_out,
+                tt_out,  # tt_out.cpu(blocking=True, cq_id=1),
                 mesh_composer=ttnn.ConcatMesh2dToTensor(
                     self.mesh_device,
                     dims=(3, 1) if self.args.is_galaxy else (1, -1),
@@ -300,12 +299,13 @@ def ttnn_decode_forward(
         rot_mats,
         page_table=None,
         kv_cache=None,
+        argmax_on_device=False,
     ):
         """
         This method will take device tensors and any other args to run forward.
         It returns ttnn device tensors.
""" - return self.forward( + tt_logits = self.forward( x, current_pos, rot_mats=rot_mats, @@ -314,6 +314,28 @@ def ttnn_decode_forward( kv_cache=kv_cache, ) + # Gather the output across all devices and untilize the tensor (for argmax) + if self.args.num_devices > 1: + if self.args.is_galaxy: + tt_logits = ttnn.all_gather( + tt_logits, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=self.mesh_device, + topology=ttnn.Topology.Linear, + ) + else: + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_logits = ttnn.untilize(tt_logits, use_multicore=True) + + if argmax_on_device: + tt_logits = ttnn.argmax( # TODO Add multicore support to batch > 1 + tt_logits, dim=3, use_multicore=False if self.args.max_batch_size > 1 else True # ,output_tensor=tokens + ) + + return tt_logits + def forward( self, x: ttnn.Tensor,