Skip to content

Commit

Permalink
#0: Add argmax on device to llama3 generator
Browse files Browse the repository at this point in the history
  • Loading branch information
mtairum committed Jan 17, 2025
1 parent 523fa73 commit a02d4d0
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 32 deletions.
28 changes: 17 additions & 11 deletions models/demos/llama3/demo/simple_text_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def test_llama_demo_text(
Simple Llama demo with limited dependence on reference code.
"""
mesh_device.enable_async(True)
enable_trace = True # Use tracing for better perf

print_to_file = False # Enable this flag to print the output of all users to a file

Expand Down Expand Up @@ -440,7 +441,7 @@ def test_llama_demo_text(

user_done = [False] * batch_size # Keeps track when a user reaches EoD token

# Set sampling mode
# TODO Argmax on device is only supported for batch_size=1
argmax_on_device = False if (batch_size > 1 or sampling_params["temperature"] != 0) else True

# Initial positions
Expand All @@ -466,21 +467,26 @@ def test_llama_demo_text(
logits = generator.decode_forward_text(
out_tok,
current_pos,
enable_trace=True,
enable_trace=enable_trace,
page_table=page_table,
kv_cache=tt_kv_cache,
argmax_on_device=argmax_on_device,
)

# TODO Miguel - Re-add argmax on device
# Fix use case with temperature > 0
# Get the next token
_, out_tok = sample_host(
logits,
None,
temperature=sampling_params["temperature"],
top_p=sampling_params["top_p"],
on_host=True,
)
if argmax_on_device:
out_tok = logits
if out_tok.dim() == 1:
out_tok = out_tok.unsqueeze(0)
else:
# TODO Fix use case with temperature > 0
_, out_tok = sample_host(
logits,
None,
temperature=sampling_params["temperature"],
top_p=sampling_params["top_p"],
on_host=True,
)

if iteration == 0: # First iteration will account the compile time
profiler.end(f"compile_decode_time", iteration=batch_idx)
Expand Down
71 changes: 65 additions & 6 deletions models/demos/llama3/tt/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,22 @@ def decode_forward_text(
kv_cache=None,
enable_trace=True,
read_from_device=True,
argmax_on_device=False,
):
decode_kwargs = {
"current_pos": start_pos,
"tokens": tokens,
"page_table": page_table,
"kv_cache": kv_cache,
"argmax_on_device": argmax_on_device,
}
if enable_trace:
tt_logits = self._easy_trace_text(**decode_kwargs)
else:
tt_logits = self._decode_forward_no_trace_text(**decode_kwargs)

if read_from_device:
return self.read_decode_output(tt_logits, tokens.shape[0])
return self.read_decode_output(tt_logits, tokens.shape[0], argmax_on_device)
else:
return tt_logits

Expand All @@ -199,6 +201,7 @@ def _decode_forward_no_trace_text(
current_pos,
page_table=None,
kv_cache=None,
argmax_on_device=False,
):
"""
Performs text decode step.
Expand All @@ -215,6 +218,26 @@ def _decode_forward_no_trace_text(
kv_cache=kv_cache,
)

# Gather the output across all devices and untilize the tensor (for argmax)
if self.model.args.num_devices > 1:
if self.model.args.is_galaxy:
tt_logits = ttnn.all_gather(
tt_logits,
dim=3,
num_links=2,
cluster_axis=0,
mesh_device=self.model.mesh_device,
topology=ttnn.Topology.Linear,
)
else:
tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_logits = ttnn.untilize(tt_logits, use_multicore=True)

if argmax_on_device:
tt_logits = ttnn.argmax( # TODO Add multicore support to batch > 1
tt_logits, dim=3, use_multicore=False if tokens.shape[0] > 1 else True # ,output_tensor=tokens
)

return tt_logits

def _capture_trace_text(
Expand All @@ -223,13 +246,16 @@ def _capture_trace_text(
current_pos,
page_table=None,
kv_cache=None,
argmax_on_device=False,
):
"""
Captures a trace for the decode_forward method.
"""

# Compile run
self._decode_forward_no_trace_text(tokens, current_pos, page_table=page_table, kv_cache=kv_cache)
self._decode_forward_no_trace_text(
tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device
)
logger.info("Done Compiling Model")

# Get inputs ready for trace run
Expand All @@ -241,9 +267,27 @@ def _capture_trace_text(
transformed_inputs = self.model.transform_decode_inputs_device(*device_inputs)
tt_out_trace = self.model.ttnn_decode_forward(*transformed_inputs, kv_cache=kv_cache)

if self.model.args.num_devices > 1:
if self.model.args.is_galaxy:
tt_out_trace = ttnn.all_gather(
tt_out_trace,
dim=3,
num_links=2,
cluster_axis=0,
mesh_device=self.model.mesh_device,
topology=ttnn.Topology.Linear,
)
else:
tt_out_trace = ttnn.all_gather(tt_out_trace, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_out_trace = ttnn.untilize(tt_out_trace, use_multicore=True)

if argmax_on_device:
tt_out_trace = ttnn.argmax( # TODO Add multicore support to batch > 1
tt_out_trace, dim=3, use_multicore=False if tokens.shape[0] > 1 else True # , output_tensor=tokens
)

ttnn.end_trace_capture(self.mesh_device, trace_id, cq_id=0)
logger.info("Done Capturing Decode Trace")

return trace_id, tt_out_trace, *device_inputs

def _decode_forward_trace_text(
Expand Down Expand Up @@ -275,13 +319,14 @@ def _easy_trace_text(
current_pos,
page_table=None,
kv_cache=None,
argmax_on_device=False,
):
"""
Tracing is easy! Just call this method and we'll handle tracing for you.
"""
if not hasattr(self, "trace_id_text"):
trace_id, tt_out_trace, *device_inputs = self._capture_trace_text(
tokens, current_pos, page_table=page_table, kv_cache=kv_cache
tokens, current_pos, page_table=page_table, kv_cache=kv_cache, argmax_on_device=argmax_on_device
)
self.trace_id_text = trace_id
self.trace_inputs_text = device_inputs
Expand Down Expand Up @@ -450,8 +495,8 @@ def decode_forward(
else:
return tt_logits

def read_decode_output(self, tt_logits, unpadded_batch):
logits = self.model.process_output_decode(tt_logits, B=unpadded_batch, S=1)
def read_decode_output(self, tt_logits, unpadded_batch, argmax_on_device=False):
    """Read a decode-step output tensor back from device to host.

    Thin wrapper around ``self.model.process_output_decode`` for a single
    decode step (S=1).

    Args:
        tt_logits: On-device ttnn tensor produced by a decode forward pass.
            NOTE(review): when ``argmax_on_device`` is True this presumably
            holds token ids rather than logits — confirm against
            ``process_output_decode``.
        unpadded_batch: Actual (unpadded) batch size B; used to trim padding
            from the returned tensor.
        argmax_on_device: Forwarded flag telling ``process_output_decode``
            that argmax/sampling already happened on device, so the output
            should be interpreted as tokens instead of logits.

    Returns:
        Torch tensor from ``process_output_decode`` — logits, or sampled
        tokens when ``argmax_on_device`` is True.
    """
    logits = self.model.process_output_decode(tt_logits, B=unpadded_batch, S=1, argmax_on_device=argmax_on_device)
    return logits

def _decode_forward_no_trace(
Expand Down Expand Up @@ -504,6 +549,20 @@ def _decode_forward_no_trace(
cross_page_table=tt_cross_page_table,
)

if self.model.args.num_devices > 1:
if self.model.args.is_galaxy:
tt_logits = ttnn.all_gather(
tt_logits,
dim=3,
num_links=2,
cluster_axis=0,
mesh_device=self.model.mesh_device,
topology=ttnn.Topology.Linear,
)
else:
tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_logits = ttnn.untilize(tt_logits, use_multicore=True)

return tt_logits

def _capture_trace(
Expand Down
2 changes: 2 additions & 0 deletions models/demos/llama3/tt/llama_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True
pt_out = torch.argmax(pt_input, dim=-1)

if mesh_device is None:
if pt_out.dim() == 1: # if sampling a single token re-add the batch dim to the tensor
pt_out = pt_out.unsqueeze(0)
return None, pt_out
if on_host:
return (
Expand Down
29 changes: 14 additions & 15 deletions models/demos/llama3/tt/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,22 @@ def process_output_prefill(self, tt_out, last_token_idx):
)[0, 0, last_token_idx, : self.vocab_size]
return logits

def process_output_decode(self, tt_out, B, S=1):
def process_output_decode(self, tt_out, B, S=1, argmax_on_device=False):
"""
Input is ttnn device tensor of logits. Output is torch logits tensor
Input is ttnn device tensor of logits. Output is torch logits tensor or the generated token if argmax on device
"""
if self.args.num_devices > 1:
if self.args.is_galaxy:
tt_out = ttnn.all_gather(
tt_out,
dim=3,
num_links=2,
cluster_axis=0,
mesh_device=self.mesh_device,
topology=ttnn.Topology.Linear,
)
else:
tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_out = ttnn.untilize(tt_out, use_multicore=True)
if argmax_on_device:
tt_out = ttnn.to_torch(
# tt_out.cpu(blocking=True, cq_id=1),
tt_out,
mesh_composer=ttnn.ConcatMesh2dToTensor(
self.mesh_device,
dims=(3, 1) if self.args.is_galaxy else (1, -1),
mesh_shape=self.args.cluster_shape,
),
)[0, 0, 0, :B]
return tt_out

if self.args.num_devices > 1:
tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
else:
Expand Down

0 comments on commit a02d4d0

Please sign in to comment.