Commit 8f932ca
Uplift third_party/tt-metal to 57ba436ec4366d9129df6a53b2d9e1e828ef0356 2025-02-25 (#2256)

This PR uplifts third_party/tt-metal to 57ba436ec4366d9129df6a53b2d9e1e828ef0356.

- Remove references to SimpleMeshShape after metal commit f3bb74d.
- Temporary hacks to support int32 for binary ops in runtime (binary.cpp,
  binary_composite.cpp, utils.cpp, utils.h) because tt-metal disallows uint32
  for some binary ops; see the sketch after this list.
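A minimal sketch of that workaround pattern in one place (not code from this PR): `runWithInt32Workaround` and `runOp` are hypothetical names, and the ttnn headers are assumed to be included as in binary.cpp. It only uses calls that appear in the diffs below (`get_dtype`, `::ttnn::typecast`).

```cpp
#include <functional>

// Hedged sketch: compute a binary op in int32 when inputs arrive as uint32,
// then restore the expected output dtype, mirroring binary_composite.cpp.
static ::ttnn::Tensor runWithInt32Workaround(
    ::ttnn::Tensor lhs, ::ttnn::Tensor rhs, ::ttnn::DataType outputDataType,
    const std::function<::ttnn::Tensor(const ::ttnn::Tensor &,
                                       const ::ttnn::Tensor &)> &runOp) {
  // tt-metal disallows uint32 for some binary ops, so typecast to int32.
  if (lhs.get_dtype() == ::ttnn::DataType::UINT32) {
    lhs = ::ttnn::typecast(lhs, ::ttnn::DataType::INT32);
  }
  if (rhs.get_dtype() == ::ttnn::DataType::UINT32) {
    rhs = ::ttnn::typecast(rhs, ::ttnn::DataType::INT32);
  }
  ::ttnn::Tensor out = runOp(lhs, rhs);
  // Cast back when the program expects a uint32 result.
  if (out.get_dtype() == ::ttnn::DataType::INT32 &&
      outputDataType == ::ttnn::DataType::UINT32) {
    out = ::ttnn::typecast(out, ::ttnn::DataType::UINT32);
  }
  return out;
}
```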

---------

Co-authored-by: kmabeeTT <118925087+kmabeeTT@users.noreply.github.com>
Co-authored-by: brataTT <achoudhury@tenstorrent.com>
Co-authored-by: Jackson Nie <jnie@tenstorrent.com>
4 people authored Feb 26, 2025
1 parent ceff7fe commit 8f932ca
Showing 9 changed files with 77 additions and 48 deletions.
9 changes: 4 additions & 5 deletions runtime/lib/common/system_desc.cpp
@@ -264,13 +264,12 @@ std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc(
       tt::runtime::common::getDispatchCoreType(dispatchCoreType);
   std::vector<chip_id_t> deviceIds(numDevices);
   std::iota(deviceIds.begin(), deviceIds.end(), 0);
-  ::tt::tt_metal::distributed::MeshShape meshShape = {1, numDevices};
+  ::tt::tt_metal::distributed::MeshShape meshShape{
+      1, static_cast<uint32_t>(numDevices)};
   std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice =
       ::tt::tt_metal::distributed::MeshDevice::create(
-          ::tt::tt_metal::distributed::MeshDeviceConfig{
-              .mesh_shape =
-                  ::tt::tt_metal::distributed::SimpleMeshShape(meshShape),
-              .offset = {}},
+          ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = meshShape,
+                                                        .offset = {}},
           DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, type);
   CoreCoord logical_grid_size = meshDevice->compute_with_storage_grid_size();
   LOG_INFO("Grid size = { ", logical_grid_size.x, ", ", logical_grid_size.y,
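The same construction change repeats in the two runtime.cpp diffs below, so here it is condensed into one hedged sketch (not code from the PR): `openLinearMesh` is a hypothetical helper, and the tt-metal macros (`DEFAULT_L1_SMALL_SIZE`, `DEFAULT_TRACE_REGION_SIZE`) are assumed in scope as in system_desc.cpp. After this uplift, `MeshShape` takes `uint32_t` dimensions directly and `MeshDeviceConfig` accepts it without the removed `SimpleMeshShape` wrapper.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>

// Hypothetical helper: open a 1 x numDevices mesh with the post-uplift API.
static std::shared_ptr<::tt::tt_metal::distributed::MeshDevice>
openLinearMesh(size_t numDevices, ::tt::tt_metal::DispatchCoreType type) {
  // MeshShape is built from uint32_t dims; no SimpleMeshShape wrapper.
  ::tt::tt_metal::distributed::MeshShape meshShape{
      1, static_cast<uint32_t>(numDevices)};
  return ::tt::tt_metal::distributed::MeshDevice::create(
      ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = meshShape,
                                                    .offset = {}},
      DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, /*numHWCQs=*/1, type);
}
```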
8 changes: 4 additions & 4 deletions runtime/lib/ttmetal/runtime.cpp
@@ -89,13 +89,13 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
   ::tt::tt_metal::DispatchCoreType type =
       tt::runtime::common::getDispatchCoreType(dispatchCoreType);
 
-  ::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()};
+  ::tt::tt_metal::distributed::MeshShape grid{
+      1, static_cast<uint32_t>(deviceIds.size())};
   size_t l1SmallSizeValue = l1SmallSize.value_or(DEFAULT_L1_SMALL_SIZE);
   std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice =
       ::tt::tt_metal::distributed::MeshDevice::create(
-          ::tt::tt_metal::distributed::MeshDeviceConfig{
-              .mesh_shape = ::tt::tt_metal::distributed::SimpleMeshShape(grid),
-              .offset = {}},
+          ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = grid,
+                                                        .offset = {}},
           l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, type);
 
   CoreCoord logical_grid_size = meshDevice->compute_with_storage_grid_size();
32 changes: 16 additions & 16 deletions runtime/lib/ttnn/operations/context/get_device.cpp
@@ -10,18 +10,18 @@
 
 namespace tt::runtime::ttnn::operations::context {
 
-using ::tt::tt_metal::distributed::MeshOffset;
+using ::tt::tt_metal::distributed::MeshCoordinate;
 using ::tt::tt_metal::distributed::MeshShape;
 
-static MeshOffset
-calculateMeshOffset(const ::ttnn::MeshDevice &parentMesh,
-                    const std::unordered_set<uint32_t> &desiredDeviceIds,
-                    const ::tt::target::Dim2d *subMeshShape) {
-  for (size_t row = 0; row < parentMesh.num_rows(); row++) {
-    for (size_t col = 0; col < parentMesh.num_cols(); col++) {
-      const ::ttnn::IDevice *currDevice = parentMesh.get_device(row, col);
+static MeshCoordinate
+calculateMeshCoordinate(const ::ttnn::MeshDevice &parentMesh,
+                        const std::unordered_set<uint32_t> &desiredDeviceIds,
+                        const ::tt::target::Dim2d *subMeshShape) {
+  for (uint32_t row = 0; row < parentMesh.shape()[0]; row++) {
+    for (uint32_t col = 0; col < parentMesh.shape()[1]; col++) {
+      const ::ttnn::IDevice *currDevice = parentMesh.get_device({row, col});
       if (desiredDeviceIds.contains(currDevice->id())) {
-        return MeshOffset(row, col);
+        return MeshCoordinate(row, col);
       }
     }
   }
@@ -34,9 +34,9 @@ createSubMesh(::ttnn::MeshDevice &parentMesh,
               const ::tt::target::Dim2d *subMeshShape) {
   // Carve out a submesh from the parentMesh
   MeshShape meshShape(subMeshShape->y(), subMeshShape->x());
-  MeshOffset offset =
-      calculateMeshOffset(parentMesh, desiredDeviceIds, subMeshShape);
-  return parentMesh.create_submesh(meshShape, offset);
+  MeshCoordinate coordinate =
+      calculateMeshCoordinate(parentMesh, desiredDeviceIds, subMeshShape);
+  return parentMesh.create_submesh(meshShape, coordinate);
 }
 
 void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) {
@@ -50,11 +50,11 @@ void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) {
 
   // Re-map mesh if subMeshShape cannot be a submesh of current shape
   MeshShape meshShape = meshDevice.shape();
-  if (subMeshShape->y() > static_cast<int32_t>(meshShape.num_rows) ||
-      subMeshShape->x() > static_cast<int32_t>(meshShape.num_cols)) {
+  if (subMeshShape->y() > static_cast<int32_t>(meshShape[0]) ||
+      subMeshShape->x() > static_cast<int32_t>(meshShape[1])) {
     meshDevice.reshape(MeshShape(subMeshShape->y(), subMeshShape->x()));
-    LOG_INFO("remapped mesh device shape [", meshDevice.num_rows(), ", ",
-             meshDevice.num_cols(), "]");
+    LOG_INFO("remapped mesh device shape [", meshDevice.shape()[0], ", ",
+             meshDevice.shape()[1], "]");
   }
   std::shared_ptr<::ttnn::MeshDevice> subMesh =
       createSubMesh(meshDevice, desiredDeviceIds, subMeshShape);
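A worked example of the remap path above: on a 1x8 parent mesh, a request for a 2x4 submesh fails the containment check (`subMeshShape->y()` of 2 exceeds `meshShape[0]` of 1), so the mesh is reshaped to 2x4 before the submesh is carved out. A minimal sketch under those assumptions (`meshDevice` stands for an already-open `::ttnn::MeshDevice`; not code from the PR):

```cpp
using ::tt::tt_metal::distributed::MeshCoordinate;
using ::tt::tt_metal::distributed::MeshShape;

// Assume meshDevice is an open 1x8 mesh and the program asks for 2x4.
MeshShape requested(2, 4);
if (requested[0] > meshDevice.shape()[0] ||
    requested[1] > meshDevice.shape()[1]) {
  meshDevice.reshape(requested); // same 8 devices, remapped to 2 rows x 4 cols
}
// Carve the submesh at the origin of the (possibly reshaped) mesh.
std::shared_ptr<::ttnn::MeshDevice> subMesh =
    meshDevice.create_submesh(requested, MeshCoordinate(0, 0));
```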
18 changes: 14 additions & 4 deletions runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
@@ -21,9 +21,18 @@ static void runEltwiseBinaryOp(
         std::optional<::ttnn::operations::unary::FusedActivations>,
         std::optional<::ttnn::operations::unary::UnaryWithParam>)> &ttnnOp) {
 
-  ::ttnn::Tensor *lhs = nullptr;
-  ::ttnn::Tensor *rhs = nullptr;
-  getEltwiseBinaryOpInputTensors(op, tensorPool, &lhs, &rhs);
+  ::ttnn::Tensor lhs, rhs;
+  getEltwiseBinaryOpInputTensors(op, tensorPool, lhs, rhs);
+
+  // TODO (#2272): Support for int32 is added in #2272
+  // However to_layout ops are not canonicalized properly, blocking #2272
+  // This is a hack to unblock metal uplifts for now until #2272 is merged
+  if (lhs.get_dtype() == ::ttnn::DataType::UINT32) {
+    lhs = ::ttnn::typecast(lhs, ::ttnn::DataType::INT32);
+  }
+  if (rhs.get_dtype() == ::ttnn::DataType::UINT32) {
+    rhs = ::ttnn::typecast(rhs, ::ttnn::DataType::INT32);
+  }
 
   ::ttnn::DataType outputDataType = utils::getDataType(op->out());
 
@@ -34,8 +43,9 @@ static void runEltwiseBinaryOp(
       outputMemoryConfig.has_value(),
       "Memory config must exist for device tensors");
 
-  ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputDataType, outputMemoryConfig,
+  ::ttnn::Tensor out = ttnnOp(lhs, rhs, outputDataType, outputMemoryConfig,
                               std::nullopt, std::nullopt, std::nullopt);
+
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
27 changes: 23 additions & 4 deletions runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp
@@ -16,9 +16,18 @@ static void runEltwiseBinaryCompositeOp(
         ::ttnn::Tensor(const ::ttnn::Tensor &, const ::ttnn::Tensor &,
                        const std::optional<::ttnn::MemoryConfig> &)> &ttnnOp) {
 
-  ::ttnn::Tensor *lhs = nullptr;
-  ::ttnn::Tensor *rhs = nullptr;
-  getEltwiseBinaryOpInputTensors(op, tensorPool, &lhs, &rhs);
+  ::ttnn::Tensor lhs, rhs;
+  getEltwiseBinaryOpInputTensors(op, tensorPool, lhs, rhs);
+
+  // TODO (#2272): Support for int32 is added in #2272
+  // However to_layout ops are not canonicalized properly, blocking #2272
+  // This is a hack to unblock metal uplifts for now until #2272 is merged
+  if (lhs.get_dtype() == ::ttnn::DataType::UINT32) {
+    lhs = ::ttnn::typecast(lhs, ::ttnn::DataType::INT32);
+  }
+  if (rhs.get_dtype() == ::ttnn::DataType::UINT32) {
+    rhs = ::ttnn::typecast(rhs, ::ttnn::DataType::INT32);
+  }
 
   std::optional<::ttnn::MemoryConfig> outputMemoryConfig =
       ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
@@ -27,7 +36,17 @@ static void runEltwiseBinaryCompositeOp(
       outputMemoryConfig.has_value(),
       "Memory config must exist for device tensors");
 
-  ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig);
+  ::ttnn::Tensor out = ttnnOp(lhs, rhs, outputMemoryConfig);
+
+  // TODO (#2272): Support for int32 is added in #2272
+  // However to_layout ops are not canonicalized properly, blocking #2272
+  // This is a hack to unblock metal uplifts for now until #2272 is merged
+  ::ttnn::DataType outputDataType = utils::getDataType(op->out());
+  if (out.get_dtype() == ::ttnn::DataType::INT32 &&
+      outputDataType == ::ttnn::DataType::UINT32) {
+    out = ::ttnn::typecast(out, ::ttnn::DataType::UINT32);
+  }
+
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
runtime/lib/ttnn/operations/eltwise/binary/utils.cpp
@@ -8,30 +8,30 @@
 namespace tt::runtime::ttnn::operations::binary {
 
 bool shouldSwapBinaryOperands(const ::tt::target::ttnn::EltwiseOp *op,
-                              ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) {
+                              const ::ttnn::Tensor &lhs, ::ttnn::Tensor &rhs) {
   // For scatter, we expect the left-hand side operator to be lesser or equal in
   // volume to the right hand side, so we omit the swap.
   return (op->type() != ::tt::target::ttnn::EltwiseOpType::Scatter &&
           workaround::Env::get().swapBinaryOperands &&
-          (*lhs)->volume() < (*rhs)->volume());
+          lhs.volume() < rhs.volume());
 }
 
 void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
                                     ProgramTensorPool &tensorPool,
-                                    ::ttnn::Tensor **lhs,
-                                    ::ttnn::Tensor **rhs) {
+                                    ::ttnn::Tensor &lhs, ::ttnn::Tensor &rhs) {
 
   LOG_ASSERT(op->ins()->size() == 2, "Expected 2 inputs");
-  *lhs = &(tensorPool.at(op->ins()->Get(0)->global_id()));
-  *rhs = &(tensorPool.at(op->ins()->Get(1)->global_id()));
-  DEBUG_ASSERT((*lhs)->is_allocated());
-  DEBUG_ASSERT((*rhs)->is_allocated());
+  lhs = tensorPool.at(op->ins()->Get(0)->global_id());
+  rhs = tensorPool.at(op->ins()->Get(1)->global_id());
+  DEBUG_ASSERT(lhs.is_allocated());
+  DEBUG_ASSERT(rhs.is_allocated());
 
   // Switch the order of operands if the second operand requires broadcast
   // TODO(bug #1124): We're currently swapping the operands for binary ops
   // in runtime if the lhs operand is smaller (and requires broadcast onto the
   // rhs operand). We should add this check in the compiler.
   if (shouldSwapBinaryOperands(op, lhs, rhs)) {
-    std::swap(*lhs, *rhs);
+    std::swap(lhs, rhs);
   }
 }
 
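A note on the utils.cpp change above: the old helper handed back pointers into `tensorPool`, so retyping an input in place would have mutated the pooled tensor; the new helper copies the tensor handles out by reference, which is what lets binary.cpp typecast its local `lhs`/`rhs` without touching the pool. A hedged caller-side sketch (`op` and `tensorPool` are assumed to come from the surrounding runtime context):

```cpp
::ttnn::Tensor lhs, rhs;
getEltwiseBinaryOpInputTensors(op, tensorPool, lhs, rhs);
// lhs and rhs are now local handles; retyping them leaves tensorPool intact.
if (lhs.get_dtype() == ::ttnn::DataType::UINT32) {
  lhs = ::ttnn::typecast(lhs, ::ttnn::DataType::INT32);
}
```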
runtime/lib/ttnn/operations/eltwise/binary/utils.h
@@ -10,9 +10,9 @@
 #include "ttmlir/Target/TTNN/program_generated.h"
 
 namespace tt::runtime::ttnn::operations::binary {
 
 void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
                                     ProgramTensorPool &tensorPool,
-                                    ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs);
+                                    ::ttnn::Tensor &lhs, ::ttnn::Tensor &rhs);
 
 } // namespace tt::runtime::ttnn::operations::binary
 
8 changes: 4 additions & 4 deletions runtime/lib/ttnn/runtime.cpp
@@ -210,12 +210,12 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
       tt::runtime::common::getDispatchCoreType(dispatchCoreType);
 
   LOG_ASSERT(deviceIds.size(), "No devices specified");
-  ::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()};
+  ::tt::tt_metal::distributed::MeshShape grid{
+      1, static_cast<uint32_t>(deviceIds.size())};
   size_t l1SmallSizeValue = l1SmallSize.value_or(kL1SmallSize);
   std::shared_ptr<::ttnn::MeshDevice> meshDevice = ::ttnn::MeshDevice::create(
-      ::tt::tt_metal::distributed::MeshDeviceConfig{
-          .mesh_shape = ::tt::tt_metal::distributed::SimpleMeshShape(grid),
-          .offset = {}},
+      ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = grid,
+                                                    .offset = {}},
       l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, type);
 
   CoreCoord logical_grid_size = meshDevice->compute_with_storage_grid_size();
2 changes: 1 addition & 1 deletion third_party/CMakeLists.txt
@@ -1,6 +1,6 @@
 include(ExternalProject)
 
-set(TT_METAL_VERSION "99e8f45516093967fd56ff3de98efa47868a3a02")
+set(TT_METAL_VERSION "57ba436ec4366d9129df6a53b2d9e1e828ef0356")
 
 if ("$ENV{ARCH_NAME}" STREQUAL "grayskull")
   set(ARCH_NAME "grayskull")
