Remove Shape usage from MultiDeviceStorage (#16841)
### Ticket

### Problem description
We're continuing to remove usages of Shape/LegacyShape throughout the codebase.

### What's changed
Replaced Shape usage in MultiDeviceStorage/MultiDeviceHostStorage with TensorSpec.
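
For illustration, a minimal before/after sketch of the aggregation pattern this change touches (not taken verbatim from any single file; `tensor_shards`, `owned_buffers`, and `strategy` are assumed to already be in scope):

```cpp
// Before: the storage carried only per-shard shapes, so dtype/layout/tile
// had to be passed to the Tensor constructor separately.
std::vector<ttnn::Shape> shapes;
for (const auto& shard : tensor_shards) {
    shapes.push_back(shard.get_shape());
}
auto old_storage = MultiDeviceHostStorage{strategy, owned_buffers, shapes};

// After: the storage carries a full TensorSpec per shard (shape, dtype,
// layout, memory config), so each shard is self-describing.
std::vector<ttnn::TensorSpec> specs;
for (const auto& shard : tensor_shards) {
    specs.push_back(shard.get_tensor_spec());
}
auto new_storage = MultiDeviceHostStorage{strategy, owned_buffers, specs};
```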

### Checklist
- [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12827697948)
- [x] [Model regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12827838374)
- [x] [T3K unit tests CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12827848878)
- [x] [T3K frequent CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12827843535)
- [x] New/Existing tests provide coverage for changes
sminakov-tt authored Jan 17, 2025
1 parent 58fb827 commit 2e8c7e7
Showing 11 changed files with 126 additions and 113 deletions.
6 changes: 3 additions & 3 deletions tests/ttnn/unit_tests/gtests/test_multi_device.cpp
@@ -13,15 +13,15 @@ using namespace tt::tt_metal;

Tensor create_host_multi_device_tensor(const Tensor& tensor, const ReplicateTensor& strategy) {
std::vector<OwnedBuffer> owned_buffers;
- std::vector<ttnn::Shape> shapes;
+ std::vector<ttnn::TensorSpec> specs;

for (int i = 0; i < strategy.replication_factor; i++) {
owned_buffers.push_back(std::get<OwnedStorage>(tensor.get_storage()).buffer);
- shapes.push_back(tensor.get_shape());
+ specs.push_back(tensor.get_tensor_spec());
}

return Tensor{
- MultiDeviceHostStorage(strategy, owned_buffers, shapes),
+ MultiDeviceHostStorage(strategy, owned_buffers, specs),
tensor.get_legacy_shape(),
tensor.get_dtype(),
tensor.get_layout()};
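
A hypothetical call site for the test helper above, shown only to illustrate intent (`make_host_tensor()` and the `ReplicateTensor{4}` initializer are assumptions; only the `replication_factor` member is visible in this diff):

```cpp
// Replicate a single host tensor into a 4-shard multi-device host tensor.
// Every shard reuses the same owned buffer and now carries the source
// tensor's full TensorSpec rather than just its Shape.
Tensor host_tensor = make_host_tensor();  // hypothetical source tensor
Tensor replicated = create_host_multi_device_tensor(host_tensor, ReplicateTensor{4});
```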
11 changes: 6 additions & 5 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -123,9 +123,9 @@ template <class T, DataType TensorType>
[[nodiscard]] tt::tt_metal::Tensor from_xtensors_to_host(
const std::vector<xt::xarray<T>>& buffers, const std::unordered_map<std::string, std::string>& config) {
std::vector<OwnedBuffer> host_owned_buffers;
- std::vector<ttnn::Shape> host_owned_shapes;
+ std::vector<ttnn::TensorSpec> host_owned_specs;
host_owned_buffers.reserve(buffers.size());
- host_owned_shapes.reserve(buffers.size());
+ host_owned_specs.reserve(buffers.size());
if (buffers.empty()) {
throw std::runtime_error("Cannot create a host buffer from an empty vector of xtensors!");
}
@@ -150,14 +150,15 @@ template <class T, DataType TensorType>
host_owned_buffers.push_back(owned_buffer);
}

- host_owned_shapes.push_back(shape);
+ host_owned_specs.push_back(
+     TensorSpec(shape, TensorLayout(TensorType, PageConfig(Layout::ROW_MAJOR), MemoryConfig{})));
}
auto distributed_tensor_config = get_distributed_tensor_config(config);
auto storage = tt::tt_metal::MultiDeviceHostStorage(
- distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes);
+ distributed_tensor_config, std::move(host_owned_buffers), host_owned_specs);

// remove possible paddings from the shape (it conflicts with ROW MAJOR)
- auto output = Tensor(std::move(storage), host_owned_shapes[0], TensorType, Layout::ROW_MAJOR);
+ auto output = Tensor(std::move(storage), host_owned_specs[0]);
return output;
}

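
Unlike the other call sites, the xtensor path has no existing spec to copy, so one is built from the shape. A standalone sketch of that construction, using only the constructors that appear in the diff above (the comments are interpretive):

```cpp
// Describe one host shard: its logical shape plus a layout built from the
// element data type, a row-major page configuration, and a default MemoryConfig.
ttnn::TensorSpec spec(
    shape,                              // logical shape of this xtensor shard
    TensorLayout(
        TensorType,                     // the DataType template parameter above
        PageConfig(Layout::ROW_MAJOR),  // row-major pages, no tilizing
        MemoryConfig{}));               // default memory config for host data
host_owned_specs.push_back(spec);
```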
6 changes: 3 additions & 3 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -398,17 +398,17 @@ Tensor convert_python_tensors_to_tt_tensors(
/*force_disable_borrow=*/true));
}
std::vector<OwnedBuffer> host_owned_buffers;
- std::vector<ttnn::Shape> host_owned_shapes;
+ std::vector<ttnn::TensorSpec> host_owned_specs;
for (const auto& shard : tt_shards) {
TT_ASSERT(
std::holds_alternative<OwnedStorage>(shard.get_storage()),
"Unexpected type {}",
tt::stl::get_active_type_name_in_variant(shard.get_storage()));
host_owned_buffers.push_back(std::get<OwnedStorage>(shard.get_storage()).buffer);
- host_owned_shapes.push_back(shard.shape());
+ host_owned_specs.push_back(shard.get_tensor_spec());
}
auto distributed_tensor_config = get_distributed_tensor_config(strategy);
- auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes};
+ auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_specs};

auto output = Tensor(std::move(storage), tt_shards.at(0).get_tensor_spec());
output = tt::tt_metal::set_tensor_id(output);
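
A side effect worth noting: because each shard now exposes a complete spec, a single shard can be rebuilt as a standalone host tensor from its buffer and spec alone. A sketch using the same `Tensor(storage, spec)` construction seen above:

```cpp
// Rebuild one shard as an independent host tensor: its owned buffer plus its
// TensorSpec is enough, with no separate shape/dtype/layout arguments.
Tensor standalone{
    OwnedStorage{std::get<OwnedStorage>(shard.get_storage()).buffer},
    shard.get_tensor_spec()};
```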
6 changes: 1 addition & 5 deletions ttnn/cpp/ttnn/device_operation.hpp
@@ -362,11 +362,7 @@ template <DeviceOperationConcept device_operation_t>
typename device_operation_t::tensor_args_t get_shard_tensor_args(std::size_t index, auto device, const typename device_operation_t::tensor_args_t& tensor_args) {
auto get_shard = [device](const auto& tensor) {
auto& storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
- return Tensor{
-     DeviceStorage{storage.get_buffer_for_device(device)},
-     storage.get_tensor_shape_for_device(device),
-     tensor.get_dtype(),
-     tensor.get_layout()};
+ return Tensor{DeviceStorage{storage.get_buffer_for_device(device)}, storage.get_tensor_spec_for_device(device)};
};
return tt::stl::reflection::transform_object_of_type<Tensor>(get_shard, tensor_args);
}
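
The same extraction written as a free-standing helper, purely for readability (a sketch, not code from the PR; the getter names are those in the diff, and the `IDevice*` parameter type is borrowed from ttnn/distributed/api.cpp):

```cpp
// Pull out the shard of a multi-device tensor that lives on a given device.
// Both the buffer and the spec are looked up per device, so nothing needs to
// be copied from the parent tensor's dtype/layout anymore.
Tensor get_shard_for_device(const Tensor& tensor, tt::tt_metal::IDevice* device) {
    auto& storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
    return Tensor{
        DeviceStorage{storage.get_buffer_for_device(device)},
        storage.get_tensor_spec_for_device(device)};
}
```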
42 changes: 15 additions & 27 deletions ttnn/cpp/ttnn/distributed/api.cpp
@@ -40,12 +40,7 @@ std::vector<ttnn::Tensor> get_device_tensors(const ttnn::Tensor& tensor) {
auto& host_storage = std::get<tt::tt_metal::MultiDeviceHostStorage>(tensor.get_storage());
const Tile tile = tensor.get_tensor_spec().tile();
for (int i = 0; i < host_storage.num_buffers(); ++i) {
- tensors.push_back(Tensor{
-     OwnedStorage{host_storage.get_buffer(i)},
-     host_storage.shapes[i],
-     tensor.get_dtype(),
-     tensor.get_layout(),
-     tile});
+ tensors.push_back(Tensor{OwnedStorage{host_storage.get_buffer(i)}, host_storage.specs[i]});
}
return tensors;
} else if (std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage())) {
@@ -78,11 +73,11 @@ Tensor aggregate_as_tensor(
StorageType storage_type = reference_shard.storage_type();
Tile tile = reference_shard.get_tensor_spec().tile();
if (storage_type == StorageType::OWNED) {
- std::vector<ttnn::Shape> shapes;
+ std::vector<ttnn::TensorSpec> specs;
std::vector<OwnedBuffer> host_owned_buffers;
for (const auto& shard : tensor_shards) {
host_owned_buffers.push_back(std::get<OwnedStorage>(shard.get_storage()).buffer);
- shapes.push_back(shard.get_shape());
+ specs.push_back(shard.get_tensor_spec());
Tile shard_tile = shard.get_tensor_spec().tile();
if (shard_tile != tile) {
TT_THROW(
@@ -96,7 +91,7 @@
shard_tile.get_width());
}
}
- auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), shapes};
+ auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), specs};
return Tensor(
std::move(storage),
reference_shard.get_legacy_shape(),
@@ -105,14 +100,14 @@
tile);
} else {
std::vector<int> ordered_device_ids;
- std::unordered_map<int, ttnn::Shape> shapes;
+ std::unordered_map<int, ttnn::TensorSpec> specs;
std::unordered_map<int, DeviceBuffer> device_buffers;
for (const auto& shard : tensor_shards) {
IDevice* device = std::get<DeviceStorage>(shard.get_storage()).buffer->device();
auto device_id = device->id();
ordered_device_ids.push_back(device_id);
device_buffers.insert({device->id(), std::get<DeviceStorage>(shard.get_storage()).buffer});
- shapes.insert({device->id(), shard.get_shape()});
+ specs.insert({device->id(), shard.get_tensor_spec()});
Tile shard_tile = shard.get_tensor_spec().tile();
if (shard_tile != tile) {
TT_THROW(
@@ -126,7 +121,7 @@
shard_tile.get_width());
}
}
- auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), shapes};
+ auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs};
return Tensor(
std::move(storage),
reference_shard.get_legacy_shape(),
@@ -230,10 +225,7 @@ std::vector<Tensor> get_tensors_from_multi_device_storage(const Tensor& multi_de
for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
auto device_id = tensor_storage.ordered_device_ids[i];
tensors[i] = Tensor{
-     DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)},
-     tensor_storage.shapes.at(device_id),
-     multi_device_tensor.get_dtype(),
-     multi_device_tensor.get_layout()};
+     DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, tensor_storage.specs.at(device_id)};
}
return tensors;
} else if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
@@ -243,11 +235,7 @@
tt::stl::get_active_type_name_in_variant(multi_device_tensor.get_storage()));
const auto& tensor_storage = std::get<MultiDeviceHostStorage>(multi_device_tensor.get_storage());
for (int i = 0; i < tensor_storage.num_buffers(); ++i) {
- tensors.push_back(Tensor{
-     OwnedStorage{tensor_storage.get_buffer(i)},
-     tensor_storage.shapes[i],
-     multi_device_tensor.get_dtype(),
-     multi_device_tensor.get_layout()});
+ tensors.push_back(Tensor{OwnedStorage{tensor_storage.get_buffer(i)}, tensor_storage.specs[i]});
}
} else {
TT_THROW("get_tensors_from_multi_device_storage only support multi device tensors");
@@ -263,7 +251,7 @@ Tensor create_multi_device_tensor(

if (storage_type == StorageType::MULTI_DEVICE) {
std::vector<int> ordered_device_ids;
- std::unordered_map<int, ttnn::Shape> shapes;
+ std::unordered_map<int, ttnn::TensorSpec> specs;
std::unordered_map<int, DeviceBuffer> device_buffers;
for (const auto& tensor : tensors) {
TT_ASSERT(
@@ -274,26 +262,26 @@
auto device_id = device->id();
ordered_device_ids.push_back(device_id);
device_buffers.insert({device_id, std::get<DeviceStorage>(tensor.get_storage()).buffer});
- shapes.insert({device_id, tensor.get_shape()});
+ specs.insert({device_id, tensor.get_tensor_spec()});
}
return Tensor{
- MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, shapes},
+ MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs},
tensors.at(0).get_legacy_shape(),
tensors.at(0).get_dtype(),
tensors.at(0).get_layout()};
} else if (storage_type == StorageType::MULTI_DEVICE_HOST) {
std::vector<OwnedBuffer> owned_buffers;
- std::vector<ttnn::Shape> shapes;
+ std::vector<ttnn::TensorSpec> specs;
for (const auto& tensor : tensors) {
TT_ASSERT(
std::holds_alternative<OwnedStorage>(tensor.get_storage()),
"Unexpected type {}",
tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
owned_buffers.push_back(std::get<OwnedStorage>(tensor.get_storage()).buffer);
- shapes.push_back(tensor.get_shape());
+ specs.push_back(tensor.get_tensor_spec());
}
return Tensor{
- MultiDeviceHostStorage{strategy, owned_buffers, shapes},
+ MultiDeviceHostStorage{strategy, owned_buffers, specs},
tensors.at(0).get_legacy_shape(),
tensors.at(0).get_dtype(),
tensors.at(0).get_layout()};
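
For reference, the per-device bookkeeping that MultiDeviceStorage is now built from, with interpretive comments (member and variable names as they appear in this file; the population loop is elided):

```cpp
std::vector<int> ordered_device_ids;                   // iteration order over devices
std::unordered_map<int, DeviceBuffer> device_buffers;  // device id -> that shard's buffer
std::unordered_map<int, ttnn::TensorSpec> specs;       // device id -> that shard's full spec

// ... populated per shard, as in aggregate_as_tensor above ...

// Previously the third container held ttnn::Shape, so dtype and layout still
// had to come from a reference shard; with TensorSpec each device's shard is
// fully described on its own.
auto storage = MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs};
```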
33 changes: 24 additions & 9 deletions ttnn/cpp/ttnn/operations/experimental/reshape/view.cpp
@@ -17,7 +17,6 @@ Tensor tensor_reshape(
const Tensor& input_tensor, const ttnn::SimpleShape& new_logical_shape, const ttnn::SimpleShape& new_padded_shape) {
ZoneScoped;
GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_logical_shape, new_padded_shape);
- const auto tile = input_tensor.get_tensor_spec().tile();
auto new_spec = ttnn::TensorSpec(
new_logical_shape,
TensorLayout::fromPaddedShape(
@@ -27,25 +26,41 @@
new_logical_shape,
new_padded_shape));
auto output = std::visit(
- [&input_tensor, &new_spec, &new_logical_shape, &new_padded_shape, &tile](auto&& storage) -> Tensor {
+ [&input_tensor, &new_spec, &new_logical_shape, &new_padded_shape](auto&& storage) -> Tensor {
using T = std::decay_t<decltype(storage)>;
const auto& tensor = input_tensor;
if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
auto updated_storage = std::get<T>(tensor.get_storage());
- auto shape = ttnn::Shape{tt::tt_metal::LegacyShape{new_logical_shape.view(), new_padded_shape.view()}};
- for (int i = 0; i < updated_storage.shapes.size(); i++) {
-     updated_storage.shapes[i] = shape;
+ for (int i = 0; i < updated_storage.specs.size(); i++) {
+     const auto& prev_spec = updated_storage.specs[i];
+     TensorSpec spec(
+         new_logical_shape,
+         TensorLayout::fromPaddedShape(
+             prev_spec.data_type(),
+             prev_spec.page_config(),
+             prev_spec.memory_config(),
+             new_logical_shape,
+             new_padded_shape));
+     updated_storage.specs[i] = spec;
}
return Tensor(updated_storage, new_spec);
}
if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
MultiDeviceStorage updated_storage = std::get<T>(tensor.get_storage());
- std::unordered_map<int, ttnn::Shape> new_shapes;
- auto shape = ttnn::Shape{tt::tt_metal::LegacyShape{new_logical_shape.view(), new_padded_shape.view()}};
+ std::unordered_map<int, ttnn::TensorSpec> new_specs;
for (auto device_id : updated_storage.ordered_device_ids) {
- new_shapes.insert({device_id, shape});
+ const auto& prev_spec = updated_storage.specs.at(device_id);
+ TensorSpec spec(
+     new_logical_shape,
+     TensorLayout::fromPaddedShape(
+         prev_spec.data_type(),
+         prev_spec.page_config(),
+         prev_spec.memory_config(),
+         new_logical_shape,
+         new_padded_shape));
+ new_specs.insert({device_id, spec});
}
- updated_storage.shapes = new_shapes;
+ updated_storage.specs = new_specs;
return Tensor(updated_storage, new_spec);
}
if constexpr (std::is_same_v<T, DeviceStorage>) {
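
The per-shard spec rebuild above is the same small operation in both branches; a sketch of it factored into a helper (this helper is not part of the PR, but the calls are exactly those used in the diff):

```cpp
// Rebuild a shard's TensorSpec for a reshape: keep its data type, page config
// and memory config, and swap in the new logical and padded shapes.
ttnn::TensorSpec reshaped_spec(
    const ttnn::TensorSpec& prev_spec,
    const ttnn::SimpleShape& new_logical_shape,
    const ttnn::SimpleShape& new_padded_shape) {
    return ttnn::TensorSpec(
        new_logical_shape,
        TensorLayout::fromPaddedShape(
            prev_spec.data_type(),
            prev_spec.page_config(),
            prev_spec.memory_config(),
            new_logical_shape,
            new_padded_shape));
}
```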