Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preprocess Conv2D weights on Device #18272

Draft
wants to merge 34 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
35ee3ce
#18185: Change order of pad & permute
sankarmanoj-tt Feb 24, 2025
5061322
#0: First commit for loading weights on device
sankarmanoj-tt Jan 12, 2025
f5dc16e
#0: WIP Conv device weights
sankarmanoj-tt Jan 13, 2025
28cbdcb
#0: WIP Conv device weights
sankarmanoj-tt Jan 13, 2025
9f417cc
#0: Conv device weights
sankarmanoj-tt Jan 14, 2025
3b5dd5a
#0: 80% pass for loading weights on device
sankarmanoj-tt Jan 15, 2025
3e56209
#0: Shallow conv support
sankarmanoj-tt Jan 15, 2025
3c755ac
#0: rebase fix
sankarmanoj-tt Jan 16, 2025
4957959
#0: Fix pad by using multicore
sankarmanoj-tt Jan 16, 2025
dd3a1fc
#0: Fix pad by using multicore
sankarmanoj-tt Jan 16, 2025
0a348ce
#0: Fix OOM for pad
sankarmanoj-tt Jan 29, 2025
0c5c07d
#0: Fix device weights
sankarmanoj-tt Jan 30, 2025
64f345f
#0: Re-enable tests
sankarmanoj-tt Jan 30, 2025
f21f411
#0: Re-enable tests
sankarmanoj-tt Jan 30, 2025
91ad04b
#0: Re-enable tests
sankarmanoj-tt Jan 30, 2025
338f326
#0: Fix OOM for pad
sankarmanoj-tt Jan 30, 2025
d9a4ea0
#0: Build fix
sankarmanoj-tt Feb 3, 2025
3d2f880
#0: Build fix
sankarmanoj-tt Feb 3, 2025
407b23d
#0: Re-enable transpose shards for Conv2D Unit Tests
sankarmanoj-tt Feb 5, 2025
f55412f
#0: Tests fix
sankarmanoj-tt Feb 5, 2025
f988ddc
#0: Tests fix
sankarmanoj-tt Feb 6, 2025
d59351b
#0: Rebase fi
sankarmanoj-tt Feb 13, 2025
b08aa14
#0: Tests fix
sankarmanoj-tt Feb 13, 2025
d150e7b
#0: Skip weights bfloat8 on grayskull
sankarmanoj-tt Feb 13, 2025
c513634
#0: Reverted types
sankarmanoj-tt Feb 15, 2025
1890433
#0: Add flag for always preprocessing weights
sankarmanoj-tt Feb 18, 2025
e6ec890
#0: Preprocess bias on device
sankarmanoj-tt Feb 18, 2025
7dcfd06
#0: Fix conv bias
sankarmanoj-tt Feb 20, 2025
d046831
#0: Rebase fix
sankarmanoj-tt Feb 20, 2025
cec786e
#0: Rebase fix
sankarmanoj-tt Feb 20, 2025
6d48186
#0: Bug fix
sankarmanoj-tt Feb 21, 2025
4bd87af
#0: Skip test on N300
sankarmanoj-tt Feb 21, 2025
361792d
Merge branch 'smanoj/18185' into smanoj/conv_device_weights
sankarmanoj-tt Feb 25, 2025
4d5afb1
#0: Fix sweep
sankarmanoj-tt Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def run_conv(
config_override,
dilation=1,
use_shallow_conv_variant=False,
transpose_shards=True, # https://github.com/tenstorrent/tt-metal/issues/17897
fp32_accum=False,
packer_l1_acc=False,
output_layout=ttnn.TILE_LAYOUT,
Expand All @@ -72,6 +73,7 @@ def run_conv(
weight_mesh_mapper=None,
output_mesh_composer=None,
enable_split_reader=False,
preprocess_weights_on_device=True,
):
if isinstance(device, ttnn.MeshDevice):
assert input_mesh_mapper is not None, "Expected mesh mapper for input tensor when using device mesh"
Expand All @@ -91,7 +93,7 @@ def run_conv(
torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))

torch_weight_tensor = randomize_torch_tensor(torch_tensor_map, conv_weight_shape)
torch_bias_tensor = randomize_torch_tensor(torch_tensor_map, conv_bias_shape) if has_bias else None
torch_bias_tensor = randomize_torch_tensor(torch_tensor_map, conv_bias_shape) * 10 if has_bias else None

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor_nchw,
Expand Down Expand Up @@ -134,6 +136,9 @@ def run_conv(
enable_split_reader=enable_split_reader,
enable_subblock_padding=False,
output_layout=output_layout,
transpose_shards=transpose_shards,
preprocess_weights_on_device=preprocess_weights_on_device,
always_preprocess_weights=True,
)
compute_config = ttnn.init_device_compute_kernel_config(
device.arch(),
Expand All @@ -153,7 +158,7 @@ def run_conv(
conv_config.override_sharding_config = True
print("Setting num_cores_nhw to 98")

[tt_output_tensor_on_device, [out_height, out_width]] = ttnn.conv2d(
[tt_output_tensor_on_device, [out_height, out_width], [d_w, d_b]] = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
in_channels=input_channels,
Expand All @@ -174,8 +179,8 @@ def run_conv(
groups=groups,
memory_config=memory_config,
return_output_dim=True,
return_weights_and_bias=True,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor, mesh_composer=output_mesh_composer)

Expand All @@ -191,6 +196,8 @@ def run_conv(

if not fp32_accum:
pcc = 0.985
if input_channels * filter_height * filter_width > 10000:
pcc = 0.97
elif math_fidelity == ttnn.MathFidelity.LoFi and activations_dtype == ttnn.bfloat8_b:
pcc = 0.996
else:
Expand Down Expand Up @@ -384,6 +391,9 @@ def test_conv_features(
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")

if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat16 and packer_l1_acc and fp32_accum:
pytest.skip("skipping due to pack_untilize_dst issue!")

run_conv(
device,
torch_tensor_map,
Expand All @@ -407,6 +417,7 @@ def test_conv_features(
has_bias=True,
fp32_accum=fp32_accum,
packer_l1_acc=packer_l1_acc,
preprocess_weights_on_device=True,
)


Expand Down Expand Up @@ -778,7 +789,7 @@ def test_conv_for_segformer_512x512(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
Expand Down Expand Up @@ -961,6 +972,7 @@ def test_resnet50_conv_wh(
pad_w,
config_override=config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
packer_l1_acc=packer_l1_acc,
fp32_accum=False,
has_bias=has_bias,
Expand Down Expand Up @@ -1022,6 +1034,7 @@ def test_conv_mem_config_wh(
shard_layout=shard_layout,
config_override=config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
packer_l1_acc=True,
fp32_accum=False,
has_bias=True,
Expand Down Expand Up @@ -1207,7 +1220,7 @@ def test_resnet50_conv_wh_fp32(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
Expand Down Expand Up @@ -1349,7 +1362,7 @@ def test_sd_conv(
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"fp32_accum",
Expand Down Expand Up @@ -1490,7 +1503,7 @@ def test_sd_conv_wh(
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
Expand Down Expand Up @@ -1642,6 +1655,7 @@ def test_unet_conv_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
auto_shard=auto_shard,
)
Expand Down Expand Up @@ -1740,6 +1754,7 @@ def test_unet_conv_groups_2_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
auto_shard=auto_shard,
groups=groups,
Expand Down Expand Up @@ -1837,6 +1852,7 @@ def test_unet_conv_groups_4_6_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
groups=groups,
)
Expand Down Expand Up @@ -1935,12 +1951,14 @@ def test_unet_conv_groups_8_wh(
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
transpose_shards=True, ## use RM (transpose_mcast=False) with 2D on WH
output_layout=output_layout,
auto_shard=auto_shard,
groups=groups,
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, config_override",
Expand Down Expand Up @@ -2002,6 +2020,7 @@ def test_halo_reshard_conv(
)


@skip_for_grayskull()
@pytest.mark.skip("New API needs to be tested")
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2243,6 +2262,7 @@ def test_conv_groups(
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant, groups",
Expand Down Expand Up @@ -2363,6 +2383,7 @@ def test_yolov4_conv_groups_larger_than_one(
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
" output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant, groups",
Expand Down Expand Up @@ -2651,6 +2672,7 @@ def test_shallow_conv_with_tiled_input(device):

# Tests running conv2d which maps to matmul w/o sharding the input tensor.
# Output tensor is in DRAM.
@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("tiled_input", [True, False])
@pytest.mark.parametrize("input_on_device", [True, False])
Expand Down Expand Up @@ -2776,6 +2798,9 @@ def test_small_in_large_out_channels_auto_shard(device, torch_tensor_map):
padding = (0, 0)
height = 128
width = 128
if device.core_grid.y != 8 and is_wormhole_b0():
pytest.skip("Needs 8x8 grid for wormhole_b0")

run_conv(
device,
torch_tensor_map,
Expand Down
130 changes: 0 additions & 130 deletions tests/ttnn/unit_tests/operations/test_prepare_conv_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,133 +196,3 @@ def test_prepare_conv_weights(
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing


@skip_for_grayskull()
@skip_for_blackhole()
# @skip_for_wormhole_b0()
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override",
(
# rn50 layer1
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
(20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
),
)
@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"])
@pytest.mark.parametrize("has_bias", [True, False], ids=["has_bias", "no_bias"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 2**15}], indirect=True)
def test_prepare_bias(
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
packer_l1_acc,
config_override,
has_bias,
device,
):
if device.core_grid.y == 7:
pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

if batch_size == 20 and (
output_channels == 64 or (stride_h == 2 and (output_channels == 256 or output_channels == 128))
):
pytest.skip("Skipping test because it won't fit in L1!")

inp_shape = (batch_size, input_channels, input_height, input_width)
conv_weight_shape = (output_channels, input_channels, filter_height, filter_width)
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16)
torch_input_tensor = torch.randn(inp_shape, dtype=torch.bfloat16)
torch_bias_tensor = torch.randn((1, 1, 1, output_channels), dtype=torch.bfloat16) if has_bias else None

torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor,
torch_weight_tensor,
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(1, 1),
groups=1,
).permute(0, 2, 3, 1)

tt_input_tensor = ttnn.from_torch(torch_input_tensor.transpose(-3, -2).transpose(-2, -1), ttnn.bfloat16)
tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16) if has_bias else None

conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16,
input_channels_alignment=(16 if input_channels == 16 and input_height == 115 else 32),
enable_act_double_buffer=False,
enable_split_reader=False,
enable_subblock_padding=False,
)
compute_config = ttnn.init_device_compute_kernel_config(device.arch(), packer_l1_acc=packer_l1_acc)
if config_override and "act_block_h" in config_override:
conv_config.act_block_h_override = config_override["act_block_h"]

if config_override and "act_block_w_div" in config_override:
conv_config.act_block_w_div = config_override["act_block_w_div"]

if config_override and "num_cores_nhw" in config_override:
if config_override["num_cores_nhw"] == 98:
conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange((0, 0), (11, 7)), ttnn.CoreRange((0, 8), (1, 8))})
conv_config.override_sharding_config = True
print("Setting num_cores_nhw to 98")

conv_kwargs = {
"input_layout": ttnn.ROW_MAJOR_LAYOUT,
"in_channels": input_channels,
"out_channels": output_channels,
"batch_size": batch_size,
"input_height": input_height,
"input_width": input_width,
"kernel_size": (filter_height, filter_width),
"stride": (stride_h, stride_w),
"padding": (pad_h, pad_w),
"dilation": (1, 1),
"groups": 1,
"device": device,
"conv_config": conv_config,
}

tt_input_tensor = ttnn.to_device(tt_input_tensor, device)

tt_bias_tensor_formatted = (
ttnn.prepare_conv_bias(
bias_tensor=tt_bias_tensor, input_memory_config=tt_input_tensor.memory_config(), **conv_kwargs
)
if has_bias
else None
)

tt_bias_tensor_formatted = ttnn.to_device(tt_bias_tensor_formatted, device) if has_bias else None
(k := next(iter(conv_kwargs)), conv_kwargs.pop(k)) ##removing 1st element from dict
tt_output_tensor_on_device = ttnn.conv2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
bias_tensor=tt_bias_tensor_formatted,
**conv_kwargs,
compute_config=compute_config,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]
torch_output_tensor = torch_output_tensor.reshape(torch_out_golden_tensor.shape)

pcc = 0.99
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=pcc)
logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
assert passing
Loading
Loading