Sharding specs of line_all_gather for Llama3-TG #11172

Closed
kpaigwar opened this issue Aug 7, 2024 · 2 comments
Assignee: SeanNijjar
Labels: bug (Something isn't working), op_cat: ccl, P1, perf (for issues tracking performance problems/improvements)

kpaigwar (Contributor) commented Aug 7, 2024

import ttnn

########################################################################################
# Spec 1: gather fused QKV across 4 devices along cluster_axis=1, dim=0
########################################################################################
# Input: [1, 1, 32, 1280], width-sharded into (32, 32) shards (40 cores * 32 = 1280 columns).
fused_query_key_value = {'shape': [1, 1, 32, 1280],
                         'shard_shape': (32, 32)}
# Expected output: 4-way gather along dim 0, width-sharded into (32*4, 32) shards per core.
all_gather_output = {'shape': [4, 1, 32, 1280],
                     'shard_shape': (32*4, 32)}
output_mem_config = ttnn.create_sharded_memory_config(
    shape=(32*4, 32),
    core_grid=ttnn.CoreGrid(y=5, x=8),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)
gathered_tensor = ttnn.line_all_gather(fused_query_key_value, dim=0, num_links=2,
                                       cluster_axis=1, device_mesh=self.device_mesh,
                                       memory_config=output_mem_config)
########################################################################################
# Spec 2: gather attention output across 8 devices along cluster_axis=0, dim=0
########################################################################################
# Input: [1, 1, 32, 2048], width-sharded into [1, 1, 32, 64] shards per core,
# i.e. (32, 64) shards (32 cores * 64 = 2048 columns).
attn_output_tensor = {'shape': [1, 1, 32, 2048],
                      'shard_shape': [1, 1, 32, 64]}
# Expected output: 8-way gather along dim 0, width-sharded into (32*8, 64) shards per core.
all_gather_output = {'shape': [8, 1, 32, 2048],
                     'shard_shape': (32*8, 64)}
output_mem_config = ttnn.create_sharded_memory_config(
    shape=(32*8, 64),
    core_grid=ttnn.CoreGrid(y=4, x=8),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)
gathered_tensor = ttnn.line_all_gather(attn_output_tensor, dim=0, num_links=2,
                                       cluster_axis=0, device_mesh=self.device_mesh,
                                       memory_config=output_mem_config)

kpaigwar (Contributor, Author) commented Aug 7, 2024

fyi @SeanNijjar @cglagovich

SeanNijjar self-assigned this on Oct 21, 2024
SeanNijjar added the bug (Something isn't working), P1, op_cat: ccl, and perf labels on Oct 21, 2024
SeanNijjar (Contributor) commented

Closing. @kpaigwar confirmed functional correctness on TG.
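
As a quick illustration of that functional result, a hypothetical shape check (not the actual TG test) against the Spec 1 expectation, assuming gathered_tensor is the output of the Spec 1 call above:

# Hypothetical check: Spec 1 should yield a [4, 1, 32, 1280] gathered tensor.
expected_shape = [4, 1, 32, 1280]  # Spec 1 all_gather_output['shape']
actual_shape = [gathered_tensor.shape[i] for i in range(4)]
assert actual_shape == expected_shape, f"unexpected gathered shape: {actual_shape}"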
