diff --git a/models/demos/llama3/demo/simple_text_demo.py b/models/demos/llama3/demo/simple_text_demo.py index 0bcc016482d6..2eabfc02cbab 100644 --- a/models/demos/llama3/demo/simple_text_demo.py +++ b/models/demos/llama3/demo/simple_text_demo.py @@ -328,6 +328,7 @@ def test_llama_demo_text( Simple Llama demo with limited dependence on reference code. """ mesh_device.enable_async(True) + enable_trace = True # Use tracing for better perf print_to_file = False # Enable this flag to print the output of all users to a file @@ -440,7 +441,7 @@ def test_llama_demo_text( user_done = [False] * batch_size # Keeps track when a user reaches EoD token - # Set sampling mode + # TODO Argmax on device is only supported for batch_size=1 argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True # Initial positions @@ -466,21 +467,26 @@ def test_llama_demo_text( logits = generator.decode_forward_text( out_tok, current_pos, - enable_trace=True, + enable_trace=enable_trace, page_table=page_table, kv_cache=tt_kv_cache, + argmax_on_device=argmax_on_device, ) - # TODO Miguel - Re-add argmax on device - # Fix use case with temperature > 0 # Get the next token - _, out_tok = sample_host( - logits, - None, - temperature=sampling_params["temperature"], - top_p=sampling_params["top_p"], - on_host=True, - ) + if argmax_on_device: + out_tok = logits + if out_tok.dim() == 1: + out_tok = out_tok.unsqueeze(0) + else: + # TODO Fix use case with temperature > 0 + _, out_tok = sample_host( + logits, + None, + temperature=sampling_params["temperature"], + top_p=sampling_params["top_p"], + on_host=True, + ) if iteration == 0: # First iteration will account the compile time profiler.end(f"compile_decode_time", iteration=batch_idx) diff --git a/models/demos/llama3/tt/generator.py b/models/demos/llama3/tt/generator.py index ba025e73663c..5a31f9c4e20b 100644 --- a/models/demos/llama3/tt/generator.py +++ b/models/demos/llama3/tt/generator.py @@ -176,12 +176,14 @@ def decode_forward_text( kv_cache=None, enable_trace=True, read_from_device=True, + argmax_on_device=False, ): decode_kwargs = { "current_pos": start_pos, "tokens": tokens, "page_table": page_table, "kv_cache": kv_cache, + "argmax_on_device": argmax_on_device, } if enable_trace: tt_logits = self._easy_trace_text(**decode_kwargs) @@ -189,7 +191,7 @@ def decode_forward_text( tt_logits = self._decode_forward_no_trace_text(**decode_kwargs) if read_from_device: - return self.read_decode_output(tt_logits, tokens.shape[0]) + return self.read_decode_output(tt_logits, tokens.shape[0], argmax_on_device) else: return tt_logits @@ -199,6 +201,7 @@ def _decode_forward_no_trace_text( current_pos, page_table=None, kv_cache=None, + argmax_on_device=False, ): """ Performs text decode step. @@ -215,6 +218,26 @@ def _decode_forward_no_trace_text( kv_cache=kv_cache, ) + # Gather the output across all devices and untilize the tensor (for argmax) + if self.model.args.num_devices > 1: + if self.model.args.is_galaxy: + tt_logits = ttnn.all_gather( + tt_logits, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=self.model.mesh_device, + topology=ttnn.Topology.Linear, + ) + else: + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_logits = ttnn.untilize(tt_logits, use_multicore=True) + + if argmax_on_device: + tt_logits = ttnn.argmax( # TODO Add multicore support to batch > 1 + tt_logits, dim=3, use_multicore=False if tokens.shape[0] > 1 else True # ,output_tensor=tokens + ) + return tt_logits def _capture_trace_text( @@ -223,13 +246,16 @@ def _capture_trace_text( current_pos, page_table=None, kv_cache=None, + argmax_on_device=False, ): """ Captures a trace for the decode_forward method. """ # Compile run - self._decode_forward_no_trace_text(tokens, current_pos, page_table=page_table, kv_cache=kv_cache) + self._decode_forward_no_trace_text( + tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device + ) logger.info("Done Compiling Model") # Get inputs ready for trace run @@ -241,9 +267,27 @@ def _capture_trace_text( transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs) tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache) + if self.model.args.num_devices > 1: + if self.model.args.is_galaxy: + tt_out_trace = ttnn.all_gather( + tt_out_trace, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=self.model.mesh_device, + topology=ttnn.Topology.Linear, + ) + else: + tt_out_trace = ttnn.all_gather(tt_out_trace, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_out_trace = ttnn.untilize(tt_out_trace, use_multicore=True) + + if argmax_on_device: + tt_out_trace = ttnn.argmax( # TODO Add multicore support to batch > 1 + tt_out_trace, dim=3, use_multicore=False if tokens.shape[0] > 1 else True # , output_tensor=tokens + ) + ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0) logger.info("Done Capturing Decode Trace") - return trace_id, tt_out_trace, *device_inputs def _decode_forward_trace_text( @@ -275,13 +319,14 @@ def _easy_trace_text( current_pos, page_table=None, kv_cache=None, + argmax_on_device=False, ): """ Tracing is easy! Just call this method and we'll handle tracing for you. """ if not hasattr(self, "trace_id_text"): trace_id, tt_out_trace, *device_inputs = self._capture_trace_text( - tokens, current_pos, page_table=page_table, kv_cache=kv_cache + tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device ) self.trace_id_text = trace_id self.trace_inputs_text = device_inputs @@ -450,8 +495,8 @@ def decode_forward( else: return tt_logits - def read_decode_output(self, tt_logits, unpadded_batch): - logits = self.model.process_output_decode(tt_logits, B=unpadded_batch, S=1) + def read_decode_output(self, tt_logits, unpadded_batch, argmax_on_device=False): + logits = self.model.process_output_decode(tt_logits, B=unpadded_batch, S=1, argmax_on_device=argmax_on_device) return logits def _decode_forward_no_trace( @@ -504,6 +549,20 @@ def _decode_forward_no_trace( cross_page_table=tt_cross_page_table, ) + if self.model.args.num_devices > 1: + if self.model.args.is_galaxy: + tt_logits = ttnn.all_gather( + tt_logits, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=self.model.mesh_device, + topology=ttnn.Topology.Linear, + ) + else: + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_logits = ttnn.untilize(tt_logits, use_multicore=True) + return tt_logits def _capture_trace( diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index d11aac69904c..ead481c1278b 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -295,6 +295,8 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True pt_out = torch.argmax(pt_input, dim=-1) if mesh_device is None: + if pt_out.dim() == 1: # if sampling a single token re-add the batch dim to the tensor + pt_out = pt_out.unsqueeze(0) return None, pt_out if on_host: return ( diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 429d1d25c7b3..312e29e54bc1 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -242,23 +242,22 @@ def process_output_prefill(self, tt_out, last_token_idx): )[0, 0, last_token_idx, : self.vocab_size] return logits - def process_output_decode(self, tt_out, B, S=1): + def process_output_decode(self, tt_out, B, S=1, argmax_on_device=False): """ - Input is ttnn device tensor of logits. Output is torch logits tensor + Input is ttnn device tensor of logits. Output is torch logits tensor or the generated token if argmax on device """ - if self.args.num_devices > 1: - if self.args.is_galaxy: - tt_out = ttnn.all_gather( - tt_out, - dim=3, - num_links=2, - cluster_axis=0, - mesh_device=self.mesh_device, - topology=ttnn.Topology.Linear, - ) - else: - tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) - tt_out = ttnn.untilize(tt_out, use_multicore=True) + if argmax_on_device: + tt_out = ttnn.to_torch( + # tt_out.cpu(blocking=True, cq_id=1), + tt_out, + mesh_composer=ttnn.ConcatMesh2dToTensor( + self.mesh_device, + dims=(3, 1) if self.args.is_galaxy else (1, -1), + mesh_shape=self.args.cluster_shape, + ), + )[0, 0, 0, :B] + return tt_out + if self.args.num_devices > 1: tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() else: