Skip to content

Commit

Permalink
Fix crash on printing of 0d tensors (#18377)
Browse files Browse the repository at this point in the history
### Ticket
#18358

### Problem description
Currently we get a segfault trying to print a 0d tensor

### What's changed
Added a check for 0d tensor when printing
Added a test for it

### Checklist
- [x] [All post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/13554456380)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sminakov-tt authored Feb 27, 2025
1 parent 0508927 commit 680b042
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 6 additions & 0 deletions tests/ttnn/unit_tests/test_print_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,9 @@ def test_print(device, dtype, layout, profile, deallocate):
# print("\\n".join(str(tensor).split("\n")))

assert tensor_as_string == GOLDEN_TENSOR_STRINGS[(dtype, layout)]


def test_print_0d(device):
torch_tensor = torch.ones((), dtype=torch.bfloat16)
tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device)
assert str(tensor) == "ttnn.Tensor([ 1.00000], shape=Shape([]), dtype=DataType::BFLOAT16, layout=Layout::TILE)"
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/tensor/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ void to_string_row_major(
const std::size_t buffer_offset,
int64_t rank,
int64_t dim = 0) {
auto stride = strides[dim];
auto stride = dim < strides.size() ? strides[dim] : 0;

std::string spaces = std::string(TENSOR_TYPE_STRING_PLUS_OPEN_PARENTHESIS_LENGTH + dim, ' ');
std::string before;
Expand All @@ -376,7 +376,7 @@ void to_string_row_major(
ss << spaces;
}
ss << "[";
auto dimension_shortener = get_dimension_shortener(shape[-rank]);
auto dimension_shortener = get_dimension_shortener(rank != 0 ? shape[-rank] : 1);
for (std::size_t index = 0;
dimension_shortener.print_parenthesis_and_advance_index_if_reached_half_of_max_and_check_if_loop_is_done(
ss, index, before, after);
Expand All @@ -396,7 +396,7 @@ void to_string_row_major(
} else {
print_datum(ss, buffer[buffer_offset + index]);
}
print_trailing_comma(ss, index, shape[-rank], after_comma);
print_trailing_comma(ss, index, rank != 0 ? shape[-rank] : 1, after_comma);
}
ss << "]";
}
Expand Down

0 comments on commit 680b042

Please sign in to comment.